"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "749c0e39a698f7c0684ca2cd18180dbec12cffc6"
Unverified Commit 49180fd6 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Grouped 3d conv backward data support (#799)

* Grouped 3d conv backward data support

* Fix comments
parent f82bd593
...@@ -77,15 +77,10 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) ...@@ -77,15 +77,10 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using GNHWC = ck::tensor_layout::convolution::GNHWC; using namespace ck::tensor_layout::convolution;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
constexpr auto I2 = ck::Number<2>{}; constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
auto profile = [&](auto num_dim_spatial_tmp, auto profile = [&](auto num_dim_spatial_tmp,
auto out_layout, auto out_layout,
...@@ -116,36 +111,70 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) ...@@ -116,36 +111,70 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
// GNHWC_GKYXC_GNHWK if(num_dim_spatial == 2)
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{});
}
} }
else if(data_type == ConvDataType::BF16_BF16_BF16) else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{});
}
} }
} }
// NHWGC_GKYXC_NHWGK else if(num_dim_spatial == 3)
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{});
}
} }
else if(data_type == ConvDataType::BF16_BF16_BF16) else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{});
}
} }
} }
......
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif() endif()
\ No newline at end of file
...@@ -46,23 +46,36 @@ class TestGroupedConvndBwdData : public ::testing::Test ...@@ -46,23 +46,36 @@ class TestGroupedConvndBwdData : public ::testing::Test
} }
}; };
using GNHWC = ck::tensor_layout::convolution::GNHWC; using namespace ck::tensor_layout::convolution;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC; using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
using GNHWK = ck::tensor_layout::convolution::GNHWK; using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
using NHWGK = ck::tensor_layout::convolution::NHWGK; std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
using KernelTypes = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>, template <typename Tuple>
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>, class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple>
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>, {
std::tuple<float, NHWGK, GKYXC, NHWGC>, };
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>; template <typename Tuple>
TYPED_TEST_SUITE(TestGroupedConvndBwdData, KernelTypes); class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdData, Test2D) TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
...@@ -76,3 +89,15 @@ TYPED_TEST(TestGroupedConvndBwdData, Test2D) ...@@ -76,3 +89,15 @@ TYPED_TEST(TestGroupedConvndBwdData, Test2D)
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment