"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "262ab794fb52084e4494c281210e652706ce1280"
Commit 5076b38e authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Grouped 3d conv backward data support

parent 1ee99dca
...@@ -46,23 +46,36 @@ class TestGroupedConvndBwdData : public ::testing::Test ...@@ -46,23 +46,36 @@ class TestGroupedConvndBwdData : public ::testing::Test
} }
}; };
using GNHWC = ck::tensor_layout::convolution::GNHWC; using namespace ck::tensor_layout::convolution;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC; using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
using GNHWK = ck::tensor_layout::convolution::GNHWK; using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
using NHWGK = ck::tensor_layout::convolution::NHWGK; std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
using KernelTypes = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>, template <typename Tuple>
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>, class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple>
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>, {
std::tuple<float, NHWGK, GKYXC, NHWGC>, };
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>; template <typename Tuple>
TYPED_TEST_SUITE(TestGroupedConvndBwdData, KernelTypes); class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdData, Test2D) TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
...@@ -76,3 +89,15 @@ TYPED_TEST(TestGroupedConvndBwdData, Test2D) ...@@ -76,3 +89,15 @@ TYPED_TEST(TestGroupedConvndBwdData, Test2D)
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment