Commit 54d3e2f1 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/moe

parents 199f7f71 b8addae2
@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_permute_scale_6d_f32_f8_instances(
std::vector<std::unique_ptr<
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F8>, element_wise::Scale, 6>>>&);
#endif
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
{
add_device_permute_scale_6d_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F8>>)
{
add_device_permute_scale_6d_f32_f8_instances(op_ptrs);
}
#endif
}
return op_ptrs;
...
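For context, a minimal sketch (not part of the commit) of how client code would query the factory for the F32 -> F8 permute-scale instances registered above. The factory header path is an assumption; CK_ENABLE_FP8 and linking the instance library are required:

#include <iostream>
// assumed header exposing the permute-scale factory specialization
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"

int main()
{
    using Device = ck::tensor_operation::device::DeviceElementwise<
        ck::Tuple<float>, ck::Tuple<ck::f8_t>,
        ck::tensor_operation::element_wise::Scale, 6>;

    // returns one op pointer per instance added by
    // add_device_permute_scale_6d_f32_f8_instances
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<Device>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl;
    return 0;
}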
@@ -10,6 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances =
#if 0
// Disabled instances to improve compilation time
// They are listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
>;
template <index_t NDims,
@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple<
#if 0
// Disabled instances to improve compilation time
// They are listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
#endif
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
>;
#ifdef CK_ENABLE_FP8
template <index_t NDims,
typename ElementwiseOp>
using device_permute_scale_f32_f8_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F8>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
>;
#endif
// clang-format on
} // namespace instance
...
@@ -14,15 +14,24 @@ namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>>&);
extern template void add_device_reduce_instance_blockwise< F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>>&);
// clang-format on
} // namespace instance
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -272,7 +272,8 @@ check_err(const Range& out,
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
          << " number of errors: " << err_count << std::endl;
}
return res;
}
...
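As context for the check_err change above, a standalone sketch of the max-error / error-count accumulation it reports. This illustrates the pattern only; it is not CK's actual implementation:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <vector>

// Sketch: track the worst deviation and count elements beyond the tolerance,
// then report both on failure, as the updated check_err output does.
bool check_err_sketch(const std::vector<float>& out, const std::vector<float>& ref, double tol)
{
    double max_err        = 0;
    std::size_t err_count = 0;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::fabs(static_cast<double>(out[i]) - static_cast<double>(ref[i]));
        if(err > tol)
        {
            max_err = std::max(max_err, err);
            ++err_count;
        }
    }
    if(err_count != 0)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
                  << " number of errors: " << err_count << std::endl;
    }
    return err_count == 0;
}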
@@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_CONVSCALE
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScale,
F8,
F8>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwdDefault,
CombConvScale>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1P0,
CombConvScale>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
CombConvScale>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ONLY XDL_KERNELS
set(GROUPED_CONV3D_FWD_CONVSCALE_RELU
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_convscale_relu_instance ${GROUPED_CONV3D_FWD_CONVSCALE_RELU})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScaleRelu,
F8,
F8>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwdDefault,
CombConvScaleRelu>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1P0,
CombConvScaleRelu>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
CombConvScaleRelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -3,15 +3,13 @@
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;
void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -56,7 +54,6 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_in
ConvFwd1x1S1P0,
ConvScaleRelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
add_instance_library(device_permute_scale_instance
device_permute_scale_1d_fp16_instances.cpp
device_permute_scale_2d_fp16_instances.cpp
device_permute_scale_3d_fp16_instances.cpp
@@ -10,4 +10,5 @@ add_instance_library(device_permute_scale_instance
device_permute_scale_3d_fp32_instances.cpp
device_permute_scale_4d_fp32_instances.cpp
device_permute_scale_5d_fp32_instances.cpp
device_permute_scale_6d_fp32_instances.cpp
device_permute_scale_6d_fp32_fp8_instances.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Scale = element_wise::Scale;
void add_device_permute_scale_6d_f32_f8_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F8>, Scale, 6>>>&
instances)
{
#ifdef CK_ENABLE_FP8
add_device_operation_instances(instances, device_permute_scale_f32_f8_instances<6, Scale>{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
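The #ifdef/#else split above is a deliberate pattern: the registration function is always defined, so callers link unconditionally, but it only populates the vector when FP8 support is compiled in. A minimal sketch of the same pattern with illustrative names (FEATURE_ENABLED stands in for a macro such as CK_ENABLE_FP8; the real code uses CK's `ignore` helper rather than a void cast):

#include <vector>

// Hypothetical feature-gated registration function.
void add_feature_instances(std::vector<int>& instances)
{
#ifdef FEATURE_ENABLED
    instances.push_back(42); // register the feature's instances here
#else
    (void)instances; // keep the parameter "used" so strict builds stay clean
#endif
}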
@@ -10,15 +10,24 @@ namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 6, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 5, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 4, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 6, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 5, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 4, 3, ReduceAMax, UnaryAbs, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 3, 3, ReduceAMax, PassThrough, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 2, 2, ReduceAMax, PassThrough, PassThrough, true, false>>&);
template void add_device_reduce_instance_blockwise< F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>(std::vector<DeviceReducePtr<F32, F32, F32, 1, 1, ReduceAMax, PassThrough, PassThrough, true, false>>&);
// clang-format on
} // namespace instance
...
@@ -136,9 +136,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
    std::string best_op_name;
    float best_avg_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;
    ck::index_t best_split_k = 1;
    // profile device Conv instances
    bool all_pass = true;
@@ -167,99 +168,111 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
    range_copy(conv_param.input_left_pads_, begin(input_left_pads));
    range_copy(conv_param.input_right_pads_, begin(input_right_pads));
    std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
    if(split_k > 0)
    {
        split_k_list = {split_k};
    }
    for(auto& op_ptr : op_ptrs)
    {
        for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
        {
            auto argument_ptr = op_ptr->MakeArgumentPointer(
                static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
                input_lengths,
                input_strides,
                filter_lengths,
                weights_strides,
                output_lengths,
                output_strides,
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
                input_right_pads,
                in_element_op,
                wei_element_op,
                out_element_op,
                split_k_list[split_k_id]);
            const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            DeviceMem workspace_dev(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
            if(op_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                // using atomic add, so need to reset input
                wei_device_buf.SetZero();
                std::string op_name = op_ptr->GetTypeString();
                auto invoker_ptr = op_ptr->MakeInvokerPointer();
                float avg_time =
                    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
                std::size_t flop      = conv_param.GetFlops();
                std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
                float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
                float gb_per_sec = num_btype / 1.E6 / avg_time;
                std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
                          << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
                          << split_k_list[split_k_id] << std::endl;
                if(tflops > best_tflops)
                {
                    best_op_name    = op_name;
                    best_tflops     = tflops;
                    best_avg_time   = avg_time;
                    best_gb_per_sec = gb_per_sec;
                    best_split_k    = split_k_list[split_k_id];
                }
                if(do_verification)
                {
                    wei_device_buf.FromDevice(weight_device_result.mData.data());
                    bool pass = ck::utils::check_err(weight_device_result, weight_host_result);
                    if(!pass)
                    {
                        std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
                    }
                    all_pass &= pass;
                    if(do_log)
                    {
                        LogRangeAsType<float>(std::cout << "output : ", output.mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(
                            std::cout << "weight (device): ", weight_device_result.mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(
                            std::cout << "weight (host): ", weight_host_result.mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(std::cout << "input: ", input.mData, ",")
                            << std::endl;
                    }
                }
            }
            else
            {
                std::cout << op_ptr->GetTypeString() << " does not support this problem"
                          << std::endl;
            }
        }
    }
    std::cout << "Best configuration parameters:"
              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
              << best_split_k << std::endl;
    return all_pass;
}
...
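Distilled from the profiler change above, a self-contained sketch of the split-K sweep: when split_k <= 0, every candidate in the fixed list is timed and the fastest (instance, split-K) pairing wins. The "measurement" below is a hypothetical stand-in for building the argument and running the invoker:

#include <cstdlib>
#include <iostream>
#include <vector>

int main()
{
    int split_k = -1; // <= 0 means "sweep the whole candidate list"

    std::vector<int> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
    if(split_k > 0)
    {
        split_k_list = {split_k};
    }

    float best_tflops = 0;
    int best_split_k  = 1;
    for(int sk : split_k_list)
    {
        // stand-in for timing the kernel at this split-K value
        const float tflops = 100.0f / (1.0f + std::abs(sk - 8)); // hypothetical measurement
        if(tflops > best_tflops)
        {
            best_tflops  = tflops;
            best_split_k = sk;
        }
    }
    std::cout << "best SplitK " << best_split_k << std::endl;
    return 0;
}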
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
@@ -81,7 +81,6 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
split_k = std::max(1, split_k);
using F32 = float;
using F16 = ck::half_t;
...
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

# Convert miopen driver command to ck Profiler
# Example: python3 ../script/convert_miopen_driver_to_profiler.py
# /opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3
# -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -m conv -g 32 -F 1 -t 1
import argparse
import subprocess


def init_const_args(args):
    args.ck_profiler_cmd = '../build/bin/ckProfiler'
    # use decimal values
    args.init_method = 2
    # don't print tensor values
    args.log_value = 0


def run_ck_profiler_cmd(cmd):
    print("ckProfiler command:")
    print(cmd)
    subprocess.run(cmd)


def parse_data_type(args):
    if args.data_type == "fp32":
        if args.ck_profier_op == "grouped_conv_bwd_weight" or \
           args.ck_profier_op == "grouped_conv_bwd_data" or \
           args.ck_profier_op == "grouped_conv_fwd":
            args.data_type = 0
    if args.data_type == "fp16":
        if args.ck_profier_op == "grouped_conv_bwd_weight" or \
           args.ck_profier_op == "grouped_conv_bwd_data" or \
           args.ck_profier_op == "grouped_conv_fwd":
            args.data_type = 1
    if args.data_type == "int8":
        if args.ck_profier_op == "grouped_conv_bwd_weight":
            args.data_type = 4
        if args.ck_profier_op == "grouped_conv_bwd_data":
            print('Not supported data type for grouped_conv_bwd_data')
            exit(1)
        if args.ck_profier_op == "grouped_conv_fwd":
            args.data_type = 3
    if args.data_type == "bfp16":
        if args.ck_profier_op == "grouped_conv_bwd_weight" or \
           args.ck_profier_op == "grouped_conv_bwd_data" or \
           args.ck_profier_op == "grouped_conv_fwd":
            args.data_type = 2


def add_conv_params_to_cmd(args, cmd):
    if args.spatial_dim == 1:
        cmd += [str(args.fil_w), str(args.in_w)]
        cmd += [str(args.conv_stride_w), str(args.dilation_w)]
        cmd += [str(args.pad_w), str(args.pad_w)]
    elif args.spatial_dim == 2:
        cmd += [str(args.fil_h), str(args.fil_w)]
        cmd += [str(args.in_h), str(args.in_w)]
        cmd += [str(args.conv_stride_h), str(args.conv_stride_w)]
        cmd += [str(args.dilation_h), str(args.dilation_w)]
        cmd += [str(args.pad_h), str(args.pad_w)]
        cmd += [str(args.pad_h), str(args.pad_w)]
    elif args.spatial_dim == 3:
        cmd += [str(args.fil_d), str(args.fil_h), str(args.fil_w)]
        cmd += [str(args.in_d), str(args.in_h), str(args.in_w)]
        cmd += [str(args.conv_stride_d), str(args.conv_stride_h)]
        cmd += [str(args.conv_stride_w)]
        cmd += [str(args.dilation_d),
                str(args.dilation_h),
                str(args.dilation_w)]
        cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
        cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
    else:
        print('Not supported spatial dim (supported: 1, 2, 3)')
        exit(1)


def run_ck_grouped_conv_fwd(args):
    args.ck_profier_op = "grouped_conv_fwd"
    parse_data_type(args)
    # default for MIOpen NHWGC
    args.layout = 1
    # use int32 by default
    args.index_type = 0
    cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
    cmd += [str(args.data_type), str(args.layout), str(args.index_type)]
    cmd += [str(args.verify), str(args.init_method)]
    cmd += [str(args.log_value), str(args.time)]
    cmd += [str(args.spatial_dim), str(args.group_count)]
    cmd += [str(args.batchsize), str(args.out_channels)]
    cmd += [str(args.in_channels)]
    add_conv_params_to_cmd(args, cmd)
    run_ck_profiler_cmd(cmd)


def run_ck_grouped_conv_bwd_data(args):
    args.ck_profier_op = "grouped_conv_bwd_data"
    parse_data_type(args)
    # default for MIOpen NHWGC
    args.layout = 1
    cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
    cmd += [str(args.data_type), str(args.layout)]
    cmd += [str(args.verify), str(args.init_method)]
    cmd += [str(args.log_value), str(args.time)]
    cmd += [str(args.spatial_dim), str(args.group_count)]
    cmd += [str(args.batchsize), str(args.out_channels)]
    cmd += [str(args.in_channels)]
    add_conv_params_to_cmd(args, cmd)
    run_ck_profiler_cmd(cmd)


def run_ck_grouped_conv_bwd_weight(args):
    args.ck_profier_op = "grouped_conv_bwd_weight"
    parse_data_type(args)
    # default for MIOpen NHWGC
    args.layout = 2
    # Test all split K values from the list {1, 2, 4, 8, 16, 32, 64, 128}
    args.split_k_value = -1
    cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
    cmd += [str(args.data_type), str(args.layout)]
    cmd += [str(args.verify), str(args.init_method)]
    cmd += [str(args.log_value), str(args.time)]
    cmd += [str(args.spatial_dim), str(args.group_count)]
    cmd += [str(args.batchsize), str(args.out_channels)]
    cmd += [str(args.in_channels)]
    add_conv_params_to_cmd(args, cmd)
    cmd += [str(args.split_k_value)]
    run_ck_profiler_cmd(cmd)


# Get name of miopen driver, remove it from unknown
def process_miopen_driver_name(args, unknown):
    if "convint8" in unknown:
        args.data_type = 'int8'
        unknown.remove("convint8")
    elif "convbfp16" in unknown:
        args.data_type = 'bfp16'
        unknown.remove("convbfp16")
    elif "convfp16" in unknown:
        args.data_type = 'fp16'
        unknown.remove("convfp16")
    elif "conv" in unknown:
        args.data_type = 'fp32'
        unknown.remove("conv")
    else:
        print('Not supported driver (supported: conv, convfp16, convint8,'
              ' convbfp16).')
        exit(1)


def run_ck_profiler(args):
    # MIOpen takes the channel counts across all groups; the CK profiler
    # expects channels per group
    args.in_channels = int(args.in_channels / args.group_count)
    args.out_channels = int(args.out_channels / args.group_count)
    if args.forw == 0 or args.forw == 1 or args.forw == 3 or args.forw == 5:
        run_ck_grouped_conv_fwd(args)
    if args.forw == 0 or args.forw == 2 or args.forw == 3 or args.forw == 6:
        run_ck_grouped_conv_bwd_data(args)
    if args.forw == 0 or args.forw == 4 or args.forw == 5 or args.forw == 6:
        run_ck_grouped_conv_bwd_weight(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="converter",
        description="Convert miopen driver command to ck Profiler"
                    "\nExample: python3 "
                    "../script/convert_miopen_driver_to_profiler.py "
                    "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
                    "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
                    "32 -F 1 -t 1",
    )
    parser.add_argument(
        "-in_layout",
        "-I",
        default=-1,
        type=int,
        required=False,
        help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
    )
    parser.add_argument(
        "-forw",
        "-F",
        default=0,
        type=int,
        required=False,
        help="Flag enables fwd, bwd, wrw convolutions"
             "\n0 fwd+bwd+wrw (default)"
             "\n1 fwd only"
             "\n2 bwd only"
             "\n4 wrw only"
             "\n3 fwd+bwd"
             "\n5 fwd+wrw"
             "\n6 bwd+wrw"
    )
    parser.add_argument(
        "-spatial_dim",
        "-_",
        default=2,
        type=int,
        required=False,
        help="convolution spatial dimension (Default-2)"
    )
    parser.add_argument(
        "-batchsize",
        "-n",
        default=100,
        type=int,
        required=False,
        help="Mini-batch size (Default=100)"
    )
    parser.add_argument(
        "-in_channels",
        "-c",
        default=3,
        type=int,
        required=False,
        help="Number of Input Channels (Default=3)"
    )
    parser.add_argument(
        "-in_d",
        "-!",
        default=32,
        type=int,
        required=False,
        help="Input Depth (Default=32)"
    )
    parser.add_argument(
        "-in_h",
        "-H",
        default=32,
        type=int,
        required=False,
        help="Input Height (Default=32)"
    )
    parser.add_argument(
        "-in_w",
        "-W",
        default=32,
        type=int,
        required=False,
        help="Input Width (Default=32)"
    )
    parser.add_argument(
        "-out_channels",
        "-k",
        default=32,
        type=int,
        required=False,
        help="Number of Output Channels (Default=32)"
    )
    parser.add_argument(
        "-fil_d",
        "-@",
        default=3,
        type=int,
        required=False,
        help="Filter Depth (Default=3)"
    )
    parser.add_argument(
        "-fil_h",
        "-y",
        default=3,
        type=int,
        required=False,
        help="Filter Height (Default=3)"
    )
    parser.add_argument(
        "-fil_w",
        "-x",
        default=3,
        type=int,
        required=False,
        help="Filter Width (Default=3)"
    )
    parser.add_argument(
        "-conv_stride_d",
        "-#",
        default=1,
        type=int,
        required=False,
        help="Convolution Stride for Depth (Default=1)"
    )
    parser.add_argument(
        "-conv_stride_h",
        "-u",
        default=1,
        type=int,
        required=False,
        help="Convolution Stride for Height (Default=1)"
    )
    parser.add_argument(
        "-conv_stride_w",
        "-v",
        default=1,
        type=int,
        required=False,
        help="Convolution Stride for Width (Default=1)"
    )
    parser.add_argument(
        "-pad_d",
        "-$",
        default=1,
        type=int,
        required=False,
help="Zero Padding for Depth (Default=0)"
    )
    parser.add_argument(
        "-pad_h",
        "-p",
        default=1,
        type=int,
        required=False,
help="Zero Padding for Height (Default=0)"
    )
    parser.add_argument(
        "-pad_w",
        "-q",
        default=1,
        type=int,
        required=False,
help="Zero Padding for Width (Default=0)"
    )
    parser.add_argument(
        "-verify",
        "-V",
        default=1,
        type=int,
        required=False,
        help="Verify Each Layer (Default=1)"
    )
    parser.add_argument(
        "-time",
        "-t",
        default=0,
        type=int,
        required=False,
        help="Time Each Layer (Default=0)"
    )
    parser.add_argument(
        "-dilation_d",
        "-^",
        default=1,
        type=int,
        required=False,
        help="Dilation of Filter Depth (Default=1)"
    )
    parser.add_argument(
        "-dilation_h",
        "-l",
        default=1,
        type=int,
        required=False,
        help="Dilation of Filter Height (Default=1)"
    )
    parser.add_argument(
        "-dilation_w",
        "-j",
        default=1,
        type=int,
        required=False,
        help="Dilation of Filter Width (Default=1)"
    )
    parser.add_argument(
        "-group_count",
        "-g",
        type=int,
        default=1,
        required=False,
        help="Number of Groups (Default=1)"
    )
    args, unknown = parser.parse_known_args()
    init_const_args(args)
    process_miopen_driver_name(args, unknown)
    print("Ignored args:")
    print(unknown)
    run_ck_profiler(args)
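As a worked example of the mapping the script performs (assuming it is run so that ckProfiler resolves at ../build/bin/ckProfiler), the driver command from the header comment, MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -m conv -g 32 -F 1 -t 1, selects data type fp32 (code 0) and forward-only, and divides channels by the group count (64/32 = 2 per group), so the emitted command is roughly: ../build/bin/ckProfiler grouped_conv_fwd 0 1 0 1 2 0 1 2 32 32 2 2 3 3 28 28 2 2 1 1 1 1 1 1.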
if (GPU_TARGETS)
if (NOT GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_SKIP_FLAKY_F8_TEST)
set(CK_SKIP_FLAKY_F8_TEST "ON")
endif()
else()
add_definitions(-DCK_SKIP_FLAKY_F8_TEST)
set(CK_SKIP_FLAKY_F8_TEST "ON")
endif()
if (USE_BITINT_EXTENSION_INT4)
add_gtest_executable(test_int4 test_int4.cpp)
if(result EQUAL 0)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
@@ -24,33 +25,36 @@ TEST(BF8, ConvertFP32Nearest)
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_t>(0.0f)), abs_tol);
// skip the next check when flaky f8 tests are disabled (non-gfx94 targets such as gfx11)
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::min())),
abs_tol);
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_rne<bf8_t>(57344.0f)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_rne<bf8_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP32Stochastic) TEST(BF8, ConvertFP32Stochastic)
...@@ -92,34 +96,34 @@ TEST(BF8, ConvertFP16Nearest) ...@@ -92,34 +96,34 @@ TEST(BF8, ConvertFP16Nearest)
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds // convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds // convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0 // convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR( ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{57344.0})), abs_tol); half_t{57344.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{57344.0})), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0 // convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0}, ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Max())), type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80), ASSERT_NEAR(type_convert<bf8_t>(0x80),
type_convert<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to bf8 and back, check if holds // positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939}; half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds // negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351}; half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds // positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175}; pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds // negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587}; neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
} }
TEST(BF8, ConvertFP16Stochastic) TEST(BF8, ConvertFP16Stochastic)
......
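The 57344.0 ceiling asserted above is the largest finite bf8_t value. Assuming CK's bf8 here uses the e5m2 FNUZ layout (1 sign, 5 exponent, 2 mantissa bits, bias 16, no infinity encoding, NaN = 0x80, consistent with the 0x80-as-qNaN checks), that maximum falls out of the bit layout directly; a quick Python sketch of the arithmetic:

    # Sketch (assumed e5m2 FNUZ layout): largest finite 8-bit value.
    bias = 16
    exp_field = 0b11111   # 31: all-ones exponent is still a normal number in FNUZ
    frac = 0b11 / 4       # two mantissa bits, all ones -> 0.75
    print((1 + frac) * 2 ** (exp_field - bias))  # 57344.0, matching the clipping checks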
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"

using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::f8_t;
using ck::half_t;

...@@ -24,33 +25,36 @@ TEST(FP8, ConvertFP32Nearest)
    // fix the tolerance value
    float abs_tol = 1e-6;
    // convert 0 float to fp8 and back, check if holds
    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
    // convert minimal float to fp8 and back, check if holds
    ASSERT_NEAR(std::numeric_limits<float>::min(),
                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())),
                abs_tol);
#endif
    // convert maximal f8_t to float and check if equal to 240.0
    ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol);
    // convert maximal float to fp8 and back, check if clipped to 240.0
    ASSERT_NEAR(240.0f,
                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())),
                abs_tol);
    // convert inf float to f8_t and check if it is qNan
    ASSERT_NEAR(type_convert<f8_t>(0x80),
                f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()),
                abs_tol);
    // positive norm float value to fp8 and back, check if holds
    float pos_float = 0.017578125f;
    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
    // negative norm float value to fp8 and back, check if holds
    float neg_float = -0.015625f;
    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
    // positive subnorm float value to fp8 and back, check if holds
    pos_float = 0.00390625f;
    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
    // negative subnorm float value to fp8 and back, check if holds
    neg_float = -0.001953125f;
    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
}

TEST(FP8, ConvertFP32Stochastic)
...@@ -92,33 +96,33 @@ TEST(FP8, ConvertFP16Nearest)
    // fix the tolerance value
    float abs_tol = 1e-3;
    // convert 0 fp16 to fp8 and back, check if holds
    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol);
    // convert minimal fp16 to fp8 and back, check if holds
    ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
                type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())),
                abs_tol);
    // convert maximal f8_t to fp16 and check if equal to 240.0
    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol);
    // convert maximal fp16 to fp8 and back, check if clipped to 240.0
    ASSERT_NEAR(half_t{240.0},
                type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())),
                abs_tol);
    // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
    ASSERT_NEAR(type_convert<f8_t>(0x80),
                f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
                abs_tol);
    // positive norm fp16 value to fp8 and back, check if holds
    half_t pos_half = half_t{0.017578125};
    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
    // negative norm fp16 value to fp8 and back, check if holds
    half_t neg_half = half_t{-0.015625};
    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
    // positive subnorm fp16 value to fp8 and back, check if holds
    pos_half = half_t{0.00390625};
    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
    // negative subnorm fp16 value to fp8 and back, check if holds
    neg_half = half_t{-0.001953125};
    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
}

TEST(FP8, ConvertFP16Stochastic)
......
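Likewise, the 240.0 ceiling in the FP8 tests matches an e4m3 FNUZ layout (bias 8, no infinities), under the same assumption as the bf8 sketch:

    # Sketch (assumed e4m3 FNUZ layout): largest finite 8-bit value.
    bias = 8
    exp_field = 0b1111    # 15: all-ones exponent is still a normal number in FNUZ
    frac = 0b111 / 8      # three mantissa bits, all ones -> 0.875
    print((1 + frac) * 2 ** (exp_field - bias))  # 240.0, matching the clipping checks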