Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5dee1b64
Commit
5dee1b64
authored
Dec 09, 2024
by
Ville Pietilä
Browse files
Merge remote-tracking branch 'origin/develop' into vpietila/ggemm-profiling
parents
870c3a76
355893cd
Changes
108
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
688 additions
and
104 deletions
+688
-104
test/data_type/test_custom_type.cpp
test/data_type/test_custom_type.cpp
+214
-20
test/data_type/test_fp8_fnuz.cpp
test/data_type/test_fp8_fnuz.cpp
+83
-66
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+250
-0
test/grouped_convnd_bwd_data/CMakeLists.txt
test/grouped_convnd_bwd_data/CMakeLists.txt
+6
-2
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp
...ped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp
+108
-0
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
...uped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
+25
-14
test/pool/test_avg_pool2d_fwd.cpp
test/pool/test_avg_pool2d_fwd.cpp
+1
-1
test/pool/test_max_pool2d_fwd.cpp
test/pool/test_max_pool2d_fwd.cpp
+1
-1
No files found.
test/data_type/test_custom_type.cpp
View file @
5dee1b64
...
...
@@ -51,8 +51,11 @@ TEST(Custom_bool, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{})
=
custom_bool_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_bool_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_bool_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_bool_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -129,8 +132,11 @@ TEST(Custom_int8, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{})
=
custom_int8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_int8_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_int8_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_int8_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -207,8 +213,11 @@ TEST(Custom_uint8, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{})
=
custom_uint8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_uint8_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_uint8_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_uint8_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -287,8 +296,11 @@ TEST(Custom_f8, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{})
=
custom_f8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_f8_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_f8_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_f8_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -369,8 +381,11 @@ TEST(Custom_bf8, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{})
=
custom_bf8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_bf8_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_bf8_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_bf8_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -450,8 +465,11 @@ TEST(Custom_half, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{})
=
custom_half_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_half_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_half_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_half_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -533,8 +551,11 @@ TEST(Custom_bhalf, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{})
=
custom_bhalf_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_bhalf_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_bhalf_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_bhalf_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -615,8 +636,11 @@ TEST(Custom_float, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{})
=
custom_float_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_float_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_float_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_float_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -693,8 +717,11 @@ TEST(Custom_double, TestAsType)
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{})
=
custom_double_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_double_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
custom_double_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
custom_double_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
...
...
@@ -813,8 +840,11 @@ TEST(Complex_half, TestAsType)
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{})
=
complex_half_t
{
test_vec
.
at
(
num_elem
*
i
),
test_vec
.
at
(
num_elem
*
i
+
1
)};
});
// copy the vector
vector_type
<
complex_half_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
complex_half_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
complex_half_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
real
,
...
...
@@ -872,3 +902,167 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec
.
at
(
num_elem
*
i
+
1
));
});
}
#if CK_USE_OCP_FP8
TEST
(
FP8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
f8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
FP8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
vector_type
<
f8_t
,
size
>
left_vec
;
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
f8_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
ck
::
non_native_vector_base
<
ck
::
f8_ocp_t
,
2
>
nnvb_f8x2
(
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_f8x2
.
template
AsType
<
f8_t
>()(
Number
<
0
>
{}),
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_f8x2
.
template
AsType
<
f8_t
>()(
Number
<
1
>
{}),
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
}
TEST
(
FP8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
f8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
f8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
bf8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
BF8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
vector_type
<
bf8_t
,
size
>
left_vec
{
right_vec
};
// check copy assignment op
left_vec
=
right_vec
;
// overwrite right_vec with 0s
right_vec
=
vector_type
<
bf8_t
,
size
>
{};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
ck
::
non_native_vector_base
<
bf8_t
,
2
>
nnvb_bf8x2
(
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_bf8x2
.
template
AsType
<
bf8_t
>()(
Number
<
0
>
{}),
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_bf8x2
.
template
AsType
<
bf8_t
>()(
Number
<
1
>
{}),
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
}
TEST
(
BF8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
bf8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
bf8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
#endif
test/data_type/test_fp8.cpp
→
test/data_type/test_fp8
_fnuz
.cpp
View file @
5dee1b64
...
...
@@ -7,154 +7,171 @@
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_t
;
using
ck
::
f8_
fnuz_
t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
FP8
,
NumericLimits
)
TEST
(
FP8
FNUZ
,
NumericLimits
)
{
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Min
(),
type_convert
<
f8_t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Max
(),
type_convert
<
f8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Lowest
(),
type_convert
<
f8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
QuietNaN
(),
type_convert
<
f8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Min
(),
type_convert
<
f8_
fnuz_
t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Max
(),
type_convert
<
f8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Lowest
(),
type_convert
<
f8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
f8_
fnuz_
t
>
(
0x80
));
}
TEST
(
FP8
,
ConvertFP32Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
#endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
240.0
f
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
FP8
,
ConvertFP32Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
240.0
f
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
FP8
,
ConvertFP16Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP16Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
TEST
(
FP8
,
ConvertFP16Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
test/data_type/test_fp8_ocp.cpp
0 → 100644
View file @
5dee1b64
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_ocp_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
FP8OCP
,
NumericLimits
)
{
// constants given for OCP FP8
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Min
(),
type_convert
<
f8_ocp_t
>
(
0x08
));
// 0b00001000 = 2^-6
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
type_convert
<
f8_ocp_t
>
(
0x7E
));
// 0b01111110 = 448
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Lowest
(),
type_convert
<
f8_ocp_t
>
(
0xFE
));
// 0b11111110 = -448
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
QuietNaN
().
data
,
type_convert
<
f8_ocp_t
>
(
0x7F
).
data
);
// 0b01111111
EXPECT_FALSE
(
ck
::
NumericLimits
<
f8_ocp_t
>::
QuietNaN
()
==
ck
::
NumericLimits
<
f8_ocp_t
>::
QuietNaN
());
}
TEST
(
FP8OCP
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
max_f8_t_float
)),
0.0
f
);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
//-2^-6
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
//-2^-9
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// smaller than min subnorm fp8 value to fp8 must be zero
auto
less_than_min_subnorm
=
0.0009765625
f
;
// 2^-10
ASSERT_EQ
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
TEST
(
FP8OCP
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
max_f8_t_float
)),
0.0
f
);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
//-2^-6
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
constexpr
auto
min_subnorm_fp8
=
-
0.001953125
f
;
//-2^-9
ASSERT_NEAR
(
min_subnorm_fp8
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
min_subnorm_fp8
)),
0.0
f
);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto
less_than_min_subnorm
=
0.0009765625
f
;
// 2^-10
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
less_than_min_subnorm
)),
0.001953125
f
);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_f8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
max_f8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_rne
<
f8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive norm half_t value to fp8 and back, check if holds
half_t
pos_half_t
{
0.017578125
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
half_t
neg_half_t
{
-
0.015625
f
};
//-2^-6
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t
=
half_t
{
0.00390625
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t
=
half_t
{
-
0.001953125
f
};
//-2^-9
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// smaller than min subnorm fp8 value to fp8 must be zero
auto
less_than_min_subnorm
=
half_t
{
0.0009765625
f
};
// 2^-10
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_f8_is_nan
(
f8_nan
.
data
));
}
TEST
(
FP8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_fp8
=
0.001953125
f
;
// 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
min_subnorm_fp8
));
const
auto
max_f8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
max_f8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_sr
<
f8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive norm half_t value to fp8 and back, check if holds
half_t
pos_half_t
{
0.017578125
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
half_t
neg_half_t
{
-
0.015625
f
};
//-2^-6
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t
=
half_t
{
0.00390625
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t
=
half_t
{
-
min_subnorm_fp8
};
//-2^-9
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto
less_than_min_subnorm
=
half_t
{
0.0009765625
f
};
// 2^-10
ASSERT_NEAR
(
type_convert
<
float
>
(
half_t_zero
),
type_convert
<
float
>
(
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
less_than_min_subnorm
))),
min_subnorm_fp8
);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_f8_is_nan
(
f8_nan
.
data
));
}
test/grouped_convnd_bwd_data/CMakeLists.txt
View file @
5dee1b64
add_gtest_executable
(
test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl
_wmma
.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_data
_xdl
test_grouped_convnd_bwd_data_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance
)
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance
)
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp
)
if
(
result EQUAL 0
)
...
...
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp
0 → 100644
View file @
5dee1b64
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template
<
typename
Tuple
>
class
TestGroupedConvndBwdDataWmma
:
public
::
testing
::
Test
{
protected:
using
DataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
OutLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
WeiLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
InLayout
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
template
<
ck
::
index_t
NDimSpatial
>
void
Run
()
{
EXPECT_FALSE
(
conv_params
.
empty
());
bool
pass
=
true
;
for
(
auto
&
param
:
conv_params
)
{
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_data_impl
<
NDimSpatial
,
OutLayout
,
WeiLayout
,
InLayout
,
DataType
,
DataType
,
DataType
>
(
true
,
// do_verification
1
,
// init_method: integer value
false
,
// do_log
false
,
// time_kernel
param
);
}
EXPECT_TRUE
(
pass
);
}
};
using
namespace
ck
::
tensor_layout
::
convolution
;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
ck
::
half_t
,
GNHWK
,
GKYXC
,
GNHWC
>
,
std
::
tuple
<
int8_t
,
GNHWK
,
GKYXC
,
GNHWC
>
,
std
::
tuple
<
ck
::
half_t
,
NHWGK
,
GKYXC
,
NHWGC
>
,
std
::
tuple
<
int8_t
,
NHWGK
,
GKYXC
,
NHWGC
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
ck
::
half_t
,
GNDHWK
,
GKZYXC
,
GNDHWC
>
,
std
::
tuple
<
int8_t
,
GNDHWK
,
GKZYXC
,
GNDHWC
>
,
std
::
tuple
<
ck
::
half_t
,
NDHWGK
,
GKZYXC
,
NDHWGC
>
,
std
::
tuple
<
int8_t
,
NDHWGK
,
GKZYXC
,
NDHWGC
>>
;
template
<
typename
Tuple
>
class
TestGroupedConvndBwdDataWmma2d
:
public
TestGroupedConvndBwdDataWmma
<
Tuple
>
{
};
template
<
typename
Tuple
>
class
TestGroupedConvndBwdDataWmma3d
:
public
TestGroupedConvndBwdDataWmma
<
Tuple
>
{
};
TYPED_TEST_SUITE
(
TestGroupedConvndBwdDataWmma2d
,
KernelTypes2d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdDataWmma3d
,
KernelTypes3d
);
TYPED_TEST
(
TestGroupedConvndBwdDataWmma2d
,
Test2D
)
{
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
(
{
2
,
2
,
4
,
192
,
192
,
{
3
,
3
},
{
28
,
28
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
template
Run
<
2
>();
}
TYPED_TEST
(
TestGroupedConvndBwdDataWmma3d
,
Test3D
)
{
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
(
{
3
,
2
,
16
,
128
,
256
,
{
1
,
1
,
1
},
{
7
,
7
,
7
},
{
2
,
2
,
2
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
conv_params
.
push_back
(
{
3
,
2
,
2
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
2
,
32
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
32
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
template
Run
<
3
>();
}
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl
_wmma
.cpp
→
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
View file @
5dee1b64
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
...
...
@@ -12,7 +12,7 @@
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template
<
typename
Tuple
>
class
TestGroupedConvndBwdData
:
public
::
testing
::
Test
class
TestGroupedConvndBwdData
Xdl
:
public
::
testing
::
Test
{
protected:
using
DataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
...
...
@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNHWK
,
GKYXC
,
GNHWC
>
,
std
::
tuple
<
ck
::
half_t
,
GNHWK
,
GKYXC
,
GNHWC
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNHWK
,
GKYXC
,
GNHWC
>
,
std
::
tuple
<
int8_t
,
GNHWK
,
GKYXC
,
GNHWC
>
,
std
::
tuple
<
float
,
NHWGK
,
GKYXC
,
NHWGC
>
,
std
::
tuple
<
ck
::
half_t
,
NHWGK
,
GKYXC
,
NHWGC
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NHWGK
,
GKYXC
,
NHWGC
>
,
std
::
tuple
<
int8_t
,
NHWGK
,
GKYXC
,
NHWGC
>>
;
std
::
tuple
<
ck
::
bhalf_t
,
NHWGK
,
GKYXC
,
NHWGC
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWK
,
GKZYXC
,
GNDHWC
>
,
std
::
tuple
<
ck
::
half_t
,
GNDHWK
,
GKZYXC
,
GNDHWC
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNDHWK
,
GKZYXC
,
GNDHWC
>
,
std
::
tuple
<
int8_t
,
GNDHWK
,
GKZYXC
,
GNDHWC
>
,
std
::
tuple
<
float
,
NDHWGK
,
GKZYXC
,
NDHWGC
>
,
std
::
tuple
<
ck
::
half_t
,
NDHWGK
,
GKZYXC
,
NDHWGC
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NDHWGK
,
GKZYXC
,
NDHWGC
>
,
std
::
tuple
<
int8_t
,
NDHWGK
,
GKZYXC
,
NDHWGC
>>
;
std
::
tuple
<
ck
::
bhalf_t
,
NDHWGK
,
GKZYXC
,
NDHWGC
>>
;
template
<
typename
Tuple
>
class
TestGroupedConvndBwdData2d
:
public
TestGroupedConvndBwdData
<
Tuple
>
class
TestGroupedConvndBwdData
Xdl
2d
:
public
TestGroupedConvndBwdData
Xdl
<
Tuple
>
{
};
template
<
typename
Tuple
>
class
TestGroupedConvndBwdData3d
:
public
TestGroupedConvndBwdData
<
Tuple
>
class
TestGroupedConvndBwdData
Xdl
3d
:
public
TestGroupedConvndBwdData
Xdl
<
Tuple
>
{
};
TYPED_TEST_SUITE
(
TestGroupedConvndBwdData2d
,
KernelTypes2d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdData3d
,
KernelTypes3d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdData
Xdl
2d
,
KernelTypes2d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdData
Xdl
3d
,
KernelTypes3d
);
TYPED_TEST
(
TestGroupedConvndBwdData2d
,
Test2D
)
TYPED_TEST
(
TestGroupedConvndBwdData
Xdl
2d
,
Test2D
)
{
this
->
conv_params
.
clear
();
...
...
@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
// SplitN case
this
->
conv_params
.
push_back
(
{
2
,
1
,
128
,
4
,
192
,
{
2
,
2
},
{
224
,
224
},
{
224
,
224
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
template
Run
<
2
>();
}
TYPED_TEST
(
TestGroupedConvndBwdData3d
,
Test3D
)
TYPED_TEST
(
TestGroupedConvndBwdData
Xdl
3d
,
Test3D
)
{
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
(
...
...
@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
// SplitN case
this
->
conv_params
.
push_back
({
3
,
1
,
128
,
4
,
192
,
{
2
,
2
,
2
},
{
2
,
224
,
224
},
{
1
,
224
,
224
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
template
Run
<
3
>();
}
test/pool/test_avg_pool2d_fwd.cpp
View file @
5dee1b64
...
...
@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE
(
AvgPool2D_I8
,
AvgPool2D_I8_Types
);
TYPED_TEST_SUITE
(
AvgPool2D_F8
,
AvgPool2D_F8_Types
);
TYPED_TEST
(
AvgPool2D_F32
,
AvgPool2D_
I8
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_F32
,
AvgPool2D_
F32
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_F16
,
AvgPool2D_F16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_BF16
,
AvgPool2D_BF16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_I8
,
AvgPool2D_I8_Test
)
{
this
->
Run
();
}
...
...
test/pool/test_max_pool2d_fwd.cpp
View file @
5dee1b64
...
...
@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE
(
MaxPool2D_I8
,
MaxPool2D_I8_Types
);
TYPED_TEST_SUITE
(
MaxPool2D_F8
,
MaxPool2D_F8_Types
);
TYPED_TEST
(
MaxPool2D_F32
,
MaxPool2D_
I8
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_F32
,
MaxPool2D_
F32
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_F16
,
MaxPool2D_F16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_BF16
,
MaxPool2D_BF16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_I8
,
MaxPool2D_I8_Test
)
{
this
->
Run
();
}
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment