Commit 69ccb257 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Add instances for small K and small C

parent 11e39518
...@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1); const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{ {
return false; return false;
} }
...@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// check tile size // check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{ {
return false; return false;
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock; const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
......
...@@ -364,20 +364,20 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -364,20 +364,20 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmm_padded_grid_desc =
out_gemmk_gemmmraw_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple( out_gemmk_gemmmraw_grid_desc,
make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(AK1, GemmMPerBlock),
make_pass_through_transform(out_gemmk_gemmmraw_grid_desc.GetLength(I1))), Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc; return out_gemmak0_gemmm_gemmak1_grid_desc;
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
...@@ -457,20 +457,20 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -457,20 +457,20 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmm_padded_grid_desc =
out_gemmk_gemmmraw_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple( out_gemmk_gemmmraw_grid_desc,
make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(AK1, GemmMPerBlock),
make_pass_through_transform(out_gemmk_gemmmraw_grid_desc.GetLength(I1))), Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc; return out_gemmak0_gemmm_gemmak1_grid_desc;
} }
else else
...@@ -614,21 +614,20 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -614,21 +614,20 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmn_padded_grid_desc =
wei_gemmk_gemmnraw_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C)), make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
GemmNPerBlock,
BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc; return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
...@@ -691,22 +690,21 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -691,22 +690,21 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemm_padded_grid_desc =
wei_gemmk_gemmnraw_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), wei_gemmk_gemmnraw_grid_desc,
make_pass_through_transform(C)), make_tuple(BK1, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = return wei_gemmbk0_gemm_gemmbk1_grid_desc;
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
GemmNPerBlock,
BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
} }
else else
{ {
......
...@@ -87,6 +87,9 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) ...@@ -87,6 +87,9 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>(); this->template Run<2>();
} }
...@@ -99,5 +102,11 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) ...@@ -99,5 +102,11 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>(); this->template Run<3>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment