Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
755ace59
Unverified
Commit
755ace59
authored
Sep 19, 2023
by
zjing14
Committed by
GitHub
Sep 19, 2023
Browse files
Merge branch 'develop' into add_int8_wmma_example_instance
parents
e24f37fb
63cd4592
Changes
108
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
292 additions
and
48 deletions
+292
-48
profiler/include/profiler/profile_gemm_multiply_add_impl.hpp
profiler/include/profiler/profile_gemm_multiply_add_impl.hpp
+2
-2
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+4
-0
profiler/src/profile_gemm_multiply_add.cpp
profiler/src/profile_gemm_multiply_add.cpp
+5
-1
profiler/src/profile_gemm_splitk.cpp
profiler/src/profile_gemm_splitk.cpp
+5
-1
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+9
-2
test/data_type/bf8.cpp
test/data_type/bf8.cpp
+158
-0
test/data_type/f8.cpp
test/data_type/f8.cpp
+57
-24
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+52
-18
No files found.
profiler/include/profiler/profile_gemm_multiply_add_impl.hpp
View file @
755ace59
...
...
@@ -83,8 +83,8 @@ bool profile_gemm_multiply_add_impl(int do_verification,
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
1
,
1
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.
5
,
0.
5
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
0.2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.
1
,
0.
1
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
}
...
...
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
755ace59
...
...
@@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
CDataType
,
f8_t
>
)
...
...
@@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification,
}
else
{
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#if defined CK_ENABLE_FP8
}
#endif
if
(
tflops
>
best_tflops
)
{
...
...
profiler/src/profile_gemm_multiply_add.cpp
View file @
755ace59
...
...
@@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[])
const
int
StrideD1
=
std
::
stoi
(
argv
[
14
]);
const
int
StrideE
=
std
::
stoi
(
argv
[
15
]);
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
#if defined CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{},
Row
{},
Row
{});
}
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
MatrixDataType
::
F16_F8_F32_F32_F16
&&
layout
==
MatrixLayout
::
MK_KN_MN_MN_MN
)
{
...
...
@@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{
return
profile
(
F16
{},
F8
{},
F32
{},
F32
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
Row
{},
Row
{});
}
#endif
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
profiler/src/profile_gemm_splitk.cpp
View file @
755ace59
...
...
@@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[])
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F8
=
ck
::
f8_t
;
#if defined CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
}
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
...
...
@@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
}
#endif
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
test/data_type/CMakeLists.txt
View file @
755ace59
...
...
@@ -3,5 +3,12 @@ if (USE_BITINT_EXTENSION_INT4)
target_link_libraries
(
test_int4 PRIVATE utility
)
endif
()
add_gtest_executable
(
test_fp8 fp8.cpp
)
target_link_libraries
(
test_fp8 PRIVATE utility
)
if
(
DTYPES MATCHES
"fp8"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_f8 f8.cpp
)
target_link_libraries
(
test_f8 PRIVATE utility
)
endif
()
if
(
DTYPES MATCHES
"bf8"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_bf8 bf8.cpp
)
target_link_libraries
(
test_bf8 PRIVATE utility
)
endif
()
test/data_type/bf8.cpp
0 → 100644
View file @
755ace59
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
BF8
,
NumericLimits
)
{
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Min
(),
type_convert
<
bf8_t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Max
(),
type_convert
<
bf8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Lowest
(),
type_convert
<
bf8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
QuietNaN
(),
type_convert
<
bf8_t
>
(
0x80
));
}
TEST
(
BF8
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
type_convert
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
BF8
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
BF8
,
ConvertFP16Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
},
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
half_t
{
57344.0
})),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
},
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
type_convert
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type_convert
<
bf8_t
>
(
neg_half
)),
abs_tol
);
}
TEST
(
BF8
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
half_t
{
57344.0
})),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
}
test/data_type/f
p
8.cpp
→
test/data_type/f8.cpp
View file @
755ace59
...
...
@@ -12,10 +12,11 @@ using ck::type_convert;
TEST
(
FP8
,
NumericLimits
)
{
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Min
(),
0x08
);
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Max
(),
0x77
);
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Lowest
(),
0xF7
);
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
QuietNaN
(),
0x80
);
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Min
(),
type_convert
<
f8_t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Max
(),
type_convert
<
f8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Lowest
(),
type_convert
<
f8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
QuietNaN
(),
type_convert
<
f8_t
>
(
0x80
));
}
TEST
(
FP8
,
ConvertFP32Nearest
)
...
...
@@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest)
type_convert
<
float
>
(
type_convert
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
0x80
,
type_convert
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive float value to fp8 and back, check if holds
float
pos_float
=
0.0078125
f
;
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
type_convert
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type_convert
<
f8_t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type_convert
<
f8_t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type_convert
<
f8_t
>
(
pos_float
)),
abs_tol
);
// negative float value to fp8 and back, check if holds
float
neg_float
=
-
0.0
156
25
0
f
;
// negative
subnorm
float value to fp8 and back, check if holds
neg_float
=
-
0.0
019531
25
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type_convert
<
f8_t
>
(
neg_float
)),
abs_tol
);
}
...
...
@@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic)
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
0x80
,
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive float value to fp8 and back, check if holds
float
pos_float
=
0.0078125
f
;
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
// negative float value to fp8 and back, check if holds
float
neg_float
=
-
0.0
156
25
0
f
;
// negative
subnorm
float value to fp8 and back, check if holds
neg_float
=
-
0.0
019531
25
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
}
...
...
@@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest)
type_convert
<
half_t
>
(
type_convert
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
0x80
,
type_convert
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0078125
};
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
type_convert
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type_convert
<
f8_t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type_convert
<
f8_t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type_convert
<
f8_t
>
(
pos_half
)),
abs_tol
);
// negative fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0
156
25
0
};
// negative
subnorm
fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.0
019531
25
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type_convert
<
f8_t
>
(
neg_half
)),
abs_tol
);
}
...
...
@@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic)
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
0x80
,
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0078125
};
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
// negative fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0
156
25
0
};
// negative
subnorm
fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.0
019531
25
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
755ace59
...
...
@@ -14,6 +14,8 @@
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
using
namespace
ck
::
tensor_layout
::
convolution
;
template
<
typename
Tuple
>
class
TestGroupedConvndBwdWeight
:
public
::
testing
::
Test
{
...
...
@@ -27,28 +29,59 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
using
NDimSpatial
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
ck
::
index_t
split_k
{
2
};
std
::
vector
<
ck
::
index_t
>
split_ks
{
1
,
2
};
bool
skip_case
(
const
ck
::
utils
::
conv
::
ConvParam
&
params
,
const
ck
::
index_t
split_k
)
{
// Odd K or C values are supported only by DL kernel (only applies to fp16)
// DL kernel currently supports only `split_k=1`
if
constexpr
(
std
::
is_same_v
<
InDataType
,
ck
::
half_t
>
)
{
if
(
split_k
!=
1
&&
(
params
.
K_
%
2
!=
0
||
params
.
C_
%
2
!=
0
))
{
return
true
;
}
}
// 1d NWGC is only supported by DL kernel
// DL kernel is only supported for split_k=1
if
constexpr
(
std
::
is_same_v
<
InLayout
,
NWGC
>
&&
std
::
is_same_v
<
OutLayout
,
NWGK
>
)
{
if
(
split_k
!=
1
)
{
return
true
;
}
}
return
false
;
}
void
Run
()
{
EXPECT_FALSE
(
conv_params
.
empty
());
bool
pass
=
true
;
for
(
auto
&
param
:
conv_param
s
)
for
(
auto
split_k
:
split_k
s
)
{
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
{},
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
>
(
true
,
// do_verification
1
,
// init_method: integer value
false
,
// do_log
false
,
// time_kernel
param
,
split_k
);
for
(
auto
&
param
:
conv_params
)
{
if
(
!
skip_case
(
param
,
split_k
))
{
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
{},
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
>
(
true
,
// do_verification
1
,
// init_method: integer value
false
,
// do_log
false
,
// time_kernel
param
,
split_k
);
}
}
}
EXPECT_TRUE
(
pass
);
}
...
...
@@ -69,12 +102,13 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
{
};
using
namespace
ck
::
tensor_layout
::
convolution
;
using
KernelTypes1d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>>>
;
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>>
,
std
::
tuple
<
float
,
float
,
float
,
NWGC
,
GKXC
,
NWGK
,
ck
::
Number
<
1
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NWGC
,
GKXC
,
NWGK
,
ck
::
Number
<
1
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NWGC
,
GKXC
,
NWGK
,
ck
::
Number
<
1
>>>
;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment