Commit 6c97a1e2 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into add_navi3x_cmake_option
parents b6029f98 fe96e8fb
...@@ -840,17 +840,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -840,17 +840,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", " << K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", " << ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", " << BBlockTransferSrcScalarPerVector
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -43,7 +43,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -43,7 +43,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
DataType, DataType,
DataType, DataType,
DataType>(true, // do_verification DataType>(true, // do_verification
1, // init_method integer value 1, // init_method: integer value
false, // do_log false, // do_log
false, // time_kernel false, // time_kernel
param, param,
...@@ -60,9 +60,9 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes); ...@@ -60,9 +60,9 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes);
TYPED_TEST(TestGroupedConvndBwdWeight, Test1D) TYPED_TEST(TestGroupedConvndBwdWeight, Test1D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 4, 64, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>(); this->template Run<1>();
} }
...@@ -70,11 +70,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D) ...@@ -70,11 +70,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
{2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 4, 8, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
...@@ -82,10 +82,10 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D) ...@@ -82,10 +82,10 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
{3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 4, 8, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>(); this->template Run<3>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment