"vscode:/vscode.git/clone" did not exist on "d7441614e856e4e8e6b903cf66a6f37a01cfd057"
Unverified Commit 08eb1769 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Allow building CK for specific data types and split off last remaining DL instances. (#830)

* properly split conv_nd_bwd_data instances

* split conv2d_fwd instance data types

* split the gemm, conv2d_fwd and batched_gemm_softamx_gemm

* split the tests by data types where possible

* filter examples by DTYPES

* split few remaining examples by DTYPES

* filter most instances by DTYPES

* add new lines at end of headers, fix grouped_gemm profiler

* fix syntax

* split the ckprofiler instances by DTYPES

* split the conv2d and quantization DL and XDL instances

* fix the splitting of conv2d DL instances

* split softmax and pool_fwd tests for fp16 and fp32 types

* fix syntax

* fix the dl_int8 quantization instances isolation
parent 22443f7a
...@@ -2,21 +2,26 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,21 +2,26 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) target_link_libraries(test_batched_gemm_fp16 PRIVATE utility)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) endif()
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility)
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) endif()
target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) target_link_libraries(test_batched_gemm_bf16 PRIVATE utility)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility) target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility)
target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(test_batched_gemm_gemm) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) add_custom_target(test_batched_gemm_gemm)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
set(target 1) add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
set(target 1)
endif()
endif() endif()
endforeach() endforeach()
\ No newline at end of file
# TODO: Enable for gfx90a after complier fix
if(DL_KERNELS) if(DL_KERNELS)
add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp) add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp)
target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance) target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance)
endif() endif()
...@@ -2,9 +2,11 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,9 +2,11 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility)
set(target 1) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance)
set(target 1)
endif()
endif() endif()
endforeach() endforeach()
...@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(test_batched_gemm_softmax_gemm) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) add_custom_target(test_batched_gemm_softmax_gemm)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
set(target 1) add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
set(target 1)
endif()
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -2,21 +2,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,21 +2,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(test_batched_gemm_softmax_gemm_permute) if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) endif()
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) endif()
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
add_custom_target(test_elementwise_normalization) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(test_elementwise_normalization)
add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp)
target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance)
target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance) add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16)
endif()
add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16) \ No newline at end of file
...@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(test_gemm_layernorm) add_custom_target(test_gemm_layernorm)
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
set(target 1) set(target 1)
endif()
endif() endif()
endforeach() endforeach()
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility) add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance)
endif()
\ No newline at end of file
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -12,3 +13,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -12,3 +13,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
add_custom_target(test_normalization) if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_custom_target(test_normalization)
add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp) endif()
add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp)
add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance)
target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance) target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance)
target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance) add_dependencies(test_normalization test_layernorm2d_fp32)
target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) add_dependencies(test_normalization test_groupnorm_fp32)
target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_dependencies(test_normalization test_layernorm2d_fp32) add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp)
add_dependencies(test_normalization test_layernorm2d_fp16) add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp)
add_dependencies(test_normalization test_groupnorm_fp16) target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance)
add_dependencies(test_normalization test_groupnorm_fp32) target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance)
add_dependencies(test_normalization test_layernorm2d_fp16)
add_dependencies(test_normalization test_groupnorm_fp16)
endif()
...@@ -41,8 +41,12 @@ class TestAvgPool2dFwd : public ::testing::Test ...@@ -41,8 +41,12 @@ class TestAvgPool2dFwd : public ::testing::Test
} }
}; };
#ifdef __fp16__
using KernelTypes = using KernelTypes =
::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F32, F32, F32, I32>>; ::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F32, F32, F32, I32>>;
#else
using KernelTypes = ::testing::Types<std::tuple<F32, F32, F32, I32>>;
#endif
TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes); TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes);
TYPED_TEST(TestAvgPool2dFwd, Test_Pool) TYPED_TEST(TestAvgPool2dFwd, Test_Pool)
......
...@@ -40,10 +40,12 @@ class TestAvgPool3dFwd : public ::testing::Test ...@@ -40,10 +40,12 @@ class TestAvgPool3dFwd : public ::testing::Test
} }
} }
}; };
#ifdef __fp16__
using KernelTypes = using KernelTypes =
::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F32, F32, F32, I32>>; ::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F32, F32, F32, I32>>;
#else
using KernelTypes = ::testing::Types<std::tuple<F32, F32, F32, I32>>;
#endif
TYPED_TEST_SUITE(TestAvgPool3dFwd, KernelTypes); TYPED_TEST_SUITE(TestAvgPool3dFwd, KernelTypes);
TYPED_TEST(TestAvgPool3dFwd, Test_Pool) TYPED_TEST(TestAvgPool3dFwd, Test_Pool)
{ {
......
...@@ -59,10 +59,12 @@ class TestMaxPool2dFwd : public ::testing::Test ...@@ -59,10 +59,12 @@ class TestMaxPool2dFwd : public ::testing::Test
} }
} }
}; };
#ifdef __fp16__
using KernelTypes = using KernelTypes =
::testing::Types<std::tuple<F16, F16, F16, I32>, std::tuple<F32, F32, F32, I32>>; ::testing::Types<std::tuple<F16, F16, F16, I32>, std::tuple<F32, F32, F32, I32>>;
#else
using KernelTypes = ::testing::Types<std::tuple<F32, F32, F32, I32>>;
#endif
TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes); TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes);
TYPED_TEST(TestMaxPool2dFwd, Test_Pool) TYPED_TEST(TestMaxPool2dFwd, Test_Pool)
{ {
......
...@@ -60,8 +60,12 @@ class TestMaxPool3dFwd : public ::testing::Test ...@@ -60,8 +60,12 @@ class TestMaxPool3dFwd : public ::testing::Test
} }
}; };
#ifdef __fp16__
using KernelTypes = using KernelTypes =
::testing::Types<std::tuple<F16, F16, F16, I32>, std::tuple<F32, F32, F32, I32>>; ::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F32, F32, F32, I32>>;
#else
using KernelTypes = ::testing::Types<std::tuple<F32, F32, F32, I32>>;
#endif
TYPED_TEST_SUITE(TestMaxPool3dFwd, KernelTypes); TYPED_TEST_SUITE(TestMaxPool3dFwd, KernelTypes);
TYPED_TEST(TestMaxPool3dFwd, Test_Pool) TYPED_TEST(TestMaxPool3dFwd, Test_Pool)
......
...@@ -10,8 +10,9 @@ ...@@ -10,8 +10,9 @@
template <ck::index_t N> template <ck::index_t N>
using I = ck::Number<N>; using I = ck::Number<N>;
#ifdef __fp16__
using F16 = ck::half_t; using F16 = ck::half_t;
#endif
using F32 = float; using F32 = float;
template <typename Tuple> template <typename Tuple>
...@@ -22,7 +23,9 @@ class TestSoftmax : public ck::TestSoftmax<Tuple> ...@@ -22,7 +23,9 @@ class TestSoftmax : public ck::TestSoftmax<Tuple>
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// InDataType, AccDataType, OutDataType, Rank // InDataType, AccDataType, OutDataType, Rank
#ifdef __fp16__
std::tuple< F16, F32, F16, I<3>>, std::tuple< F16, F32, F16, I<3>>,
#endif
std::tuple< F32, F32, F32, I<3>> std::tuple< F32, F32, F32, I<3>>
>; >;
// clang-format on // clang-format on
......
...@@ -10,8 +10,9 @@ ...@@ -10,8 +10,9 @@
template <ck::index_t N> template <ck::index_t N>
using I = ck::Number<N>; using I = ck::Number<N>;
#ifdef __fp16__
using F16 = ck::half_t; using F16 = ck::half_t;
#endif
using F32 = float; using F32 = float;
template <typename Tuple> template <typename Tuple>
...@@ -22,7 +23,9 @@ class TestSoftmax : public ck::TestSoftmax<Tuple> ...@@ -22,7 +23,9 @@ class TestSoftmax : public ck::TestSoftmax<Tuple>
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// InDataType, AccDataType, OutDataType, Rank // InDataType, AccDataType, OutDataType, Rank
#ifdef __fp16__
std::tuple< F16, F32, F16, I<4>>, std::tuple< F16, F32, F16, I<4>>,
#endif
std::tuple< F32, F32, F32, I<4>> std::tuple< F32, F32, F32, I<4>>
>; >;
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment