Unverified Commit 74284dd2 authored by rocking's avatar rocking Committed by GitHub
Browse files

Merge branch 'develop' into avgpool_bwd

parents 01aeb901 50643dd5
...@@ -144,7 +144,7 @@ GB/s: 127.947 ...@@ -144,7 +144,7 @@ GB/s: 127.947
## Profile grouped convolution backward weight kernels ## Profile grouped convolution backward weight kernels
```bash ```bash
# arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data) # arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight)
# arg2: data type (0: Input fp32, Weight fp32, Output fp32 # arg2: data type (0: Input fp32, Weight fp32, Output fp32
# 1: Input fp16, Weight fp16, Output fp16 # 1: Input fp16, Weight fp16, Output fp16
# 2: Input bf16, Weight fp32, Output bf16) # 2: Input bf16, Weight fp32, Output bf16)
...@@ -167,7 +167,7 @@ GB/s: 127.947 ...@@ -167,7 +167,7 @@ GB/s: 127.947
# SplitK # SplitK
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
./bin/ckProfiler grouped_conv_bwd_data 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 ./bin/ckProfiler grouped_conv_bwd_weight 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
``` ```
......
...@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in ...@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
...@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[]) ...@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
const int BatchCount = std::stoi(argv[17]); const int BatchCount = std::stoi(argv[17]);
using F16 = ck::half_t; using F16 = ck::half_t;
#ifdef __int8__
using INT8 = int8_t; using INT8 = int8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[]) ...@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
{ {
return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
} }
#ifdef __int8__
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
...@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[]) ...@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
{ {
return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
} }
#endif
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
#ifdef __int8__
using INT8 = int8_t; using INT8 = int8_t;
#endif
using NWC = ck::tensor_layout::convolution::NWC; using NWC = ck::tensor_layout::convolution::NWC;
using NHWC = ck::tensor_layout::convolution::NHWC; using NHWC = ck::tensor_layout::convolution::NHWC;
...@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{});
} }
#ifdef __int8__
else if(data_type == ConvDataType::INT8_INT8_INT8) else if(data_type == ConvDataType::INT8_INT8_INT8)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{}); return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{});
} }
#endif
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
{ {
...@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{});
} }
#ifdef __int8__
else if(data_type == ConvDataType::INT8_INT8_INT8) else if(data_type == ConvDataType::INT8_INT8_INT8)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{}); return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{});
} }
#endif
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
{ {
...@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{});
} }
#ifdef __int8__
else if(data_type == ConvDataType::INT8_INT8_INT8) else if(data_type == ConvDataType::INT8_INT8_INT8)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{}); return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{});
} }
#endif
} }
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[]) ...@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[])
const int StrideB = std::stoi(argv[12]); const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]); const int StrideC = std::stoi(argv[13]);
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; #ifdef __bf16__
using BF16 = ck::bhalf_t;
#endif
#ifdef __int8__
using INT8 = int8_t; using INT8 = int8_t;
using INT32 = int32_t; using INT32 = int32_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[])
{ {
return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
} }
#ifdef __bf16__
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
...@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[])
{ {
return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
} }
#endif
#ifdef __int8__
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
...@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[])
{ {
return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
} }
#endif
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -77,15 +77,10 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) ...@@ -77,15 +77,10 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using GNHWC = ck::tensor_layout::convolution::GNHWC; using namespace ck::tensor_layout::convolution;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
constexpr auto I2 = ck::Number<2>{}; constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
auto profile = [&](auto num_dim_spatial_tmp, auto profile = [&](auto num_dim_spatial_tmp,
auto out_layout, auto out_layout,
...@@ -116,36 +111,70 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) ...@@ -116,36 +111,70 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
// GNHWC_GKYXC_GNHWK if(num_dim_spatial == 2)
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{});
}
} }
else if(data_type == ConvDataType::BF16_BF16_BF16) else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{});
}
} }
} }
// NHWGC_GKYXC_NHWGK else if(num_dim_spatial == 3)
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{});
}
} }
else if(data_type == ConvDataType::BF16_BF16_BF16) else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{});
}
} }
} }
......
...@@ -83,19 +83,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -83,19 +83,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using GNWC = ck::tensor_layout::convolution::GNWC; using namespace ck::tensor_layout::convolution;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using GKXC = ck::tensor_layout::convolution::GKXC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
constexpr auto I1 = ck::Number<1>{}; constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{}; constexpr auto I2 = ck::Number<2>{};
...@@ -194,6 +182,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -194,6 +182,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>, ...@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
} // namespace } // namespace
TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes); TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
#ifdef __fp16
TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); } TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
#endif
#ifdef __int8__
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); } TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
#endif
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_fp32 gemm_fp32.cpp) add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
target_link_libraries(test_gemm_fp32 PRIVATE utility) target_link_libraries(test_gemm_fp32 PRIVATE utility)
target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_fp16 gemm_fp16.cpp) add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
target_link_libraries(test_gemm_fp16 PRIVATE utility) target_link_libraries(test_gemm_fp16 PRIVATE utility)
target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE utility)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC add_library(gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp instance/gemm_f16_nt_instance.cpp
...@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC ...@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp) add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility) target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/) target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE utility)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
endif()
\ No newline at end of file
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif() endif()
\ No newline at end of file
...@@ -85,7 +85,10 @@ using KernelTypes2d = ::testing::Types< ...@@ -85,7 +85,10 @@ using KernelTypes2d = ::testing::Types<
using KernelTypes3d = ::testing::Types< using KernelTypes3d = ::testing::Types<
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>, std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>>; std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment