"src/include/gridwise_direct_convolution_1.hip.hpp" did not exist on "99d05ba77f6b075852d165d93926dc67cf0cad86"
Unverified Commit 595d23be authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix transform and instances for grouped conv bwd data (#848)

* Fix transform and instances for grouped conv bwd data

* Add instances for small K and small C

* Remove workaround after fix

* Fix interface tests
parent eac50708
...@@ -200,9 +200,6 @@ ...@@ -200,9 +200,6 @@
// workaround: compiler issue on gfx908 // workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1 #define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_GITHUB_ISSUE_824 1
// flag to enable (1) or disable (0) the debugging output in some kernels // flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0 #define DEBUG_LOG 0
......
...@@ -268,10 +268,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -268,10 +268,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1); const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{ {
return false; return false;
} }
...@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
} }
// check tile size // check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{ {
return false; return false;
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock; const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
......
...@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths = constexpr auto ordered_src_access_lengths =
......
...@@ -236,8 +236,6 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -236,8 +236,6 @@ struct TransformConvBwdDataToGemm_v1
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const index_t AK0 = K / AK1;
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc = const auto out_grid_desc =
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>( make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
...@@ -247,6 +245,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -247,6 +245,8 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const index_t AK0 = math::integer_divide_ceil(K, AK1);
// A: output tensor // A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
...@@ -308,6 +308,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -308,6 +308,9 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t AK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
// A: output tensor // A: output tensor
...@@ -332,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -332,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc = const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
...@@ -340,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -340,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -352,21 +355,28 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -352,21 +355,28 @@ struct TransformConvBwdDataToGemm_v1
Sequence<2>{}, Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5, 6>{})); Sequence<5>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_pass_through_transform(AK1)), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc = const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc, out_gemmk_gemmmraw_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1), make_tuple(AK1, GemmMPerBlock),
Sequence<false, DoPadGemmM, false>{}); Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc; return out_gemmak0_gemmm_gemmak1_grid_desc;
} }
...@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
Sequence<7>{})); Sequence<7>{}));
const auto const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc = out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
...@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -437,22 +447,29 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -437,22 +447,29 @@ struct TransformConvBwdDataToGemm_v1
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}, Sequence<5>{},
Sequence<6>{}, Sequence<6>{},
Sequence<7, 8>{})); Sequence<7>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc, out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, AK0)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
make_pass_through_transform(AK1)), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc = const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc, out_gemmk_gemmmraw_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1), make_tuple(AK1, GemmMPerBlock),
Sequence<false, DoPadGemmM, false>{}); Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc; return out_gemmak0_gemmm_gemmak1_grid_desc;
} }
...@@ -505,8 +522,6 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -505,8 +522,6 @@ struct TransformConvBwdDataToGemm_v1
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const index_t BK0 = K / BK1;
// assume packed // assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d // k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C); const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
...@@ -515,6 +530,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -515,6 +530,8 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const index_t BK0 = math::integer_divide_ceil(K, BK1);
// B: weight tensor // B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
...@@ -551,6 +568,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -551,6 +568,9 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t BK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
// B weight tensor // B weight tensor
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
...@@ -566,10 +586,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -566,10 +586,9 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc = const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_pass_through_transform(K),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(i_ytilde),
...@@ -581,28 +600,33 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -581,28 +600,33 @@ struct TransformConvBwdDataToGemm_v1
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple(Sequence<0, 1>{}, make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<4>{})); Sequence<3>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc, wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_pass_through_transform(C), make_pass_through_transform(C)),
make_pass_through_transform(BK1)), make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, wei_gemmk_gemmnraw_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), make_tuple(BK1, GemmNPerBlock),
GemmNPerBlock, Sequence<true, DoPadGemmN>{});
BK1),
Sequence<false, DoPadGemmN, false>{}); const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemmn_gemmbk1_grid_desc; return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
} }
...@@ -631,10 +655,10 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -631,10 +655,10 @@ struct TransformConvBwdDataToGemm_v1
Sequence<5, 6>{}, Sequence<5, 6>{},
Sequence<7>{})); Sequence<7>{}));
const auto wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc = const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_pass_through_transform(K),
make_slice_transform(ZDot, I0, ZDotSlice), make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot, I0, XDotSlice),
...@@ -650,33 +674,37 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -650,33 +674,37 @@ struct TransformConvBwdDataToGemm_v1
Sequence<4>{}, Sequence<4>{},
Sequence<6>{}, Sequence<6>{},
Sequence<7>{}), Sequence<7>{}),
make_tuple(Sequence<0, 1>{}, make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<4>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<>{}, Sequence<>{},
Sequence<5>{})); Sequence<4>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc, wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple( make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, BK0)), make_pass_through_transform(C)),
make_pass_through_transform(C), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_pass_through_transform(BK1)), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = const auto wei_gemmk_gemm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, wei_gemmk_gemmnraw_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), make_tuple(BK1, GemmNPerBlock),
GemmNPerBlock, Sequence<true, DoPadGemmN>{});
BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc; const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemm_gemmbk1_grid_desc;
} }
else else
{ {
......
...@@ -38,7 +38,7 @@ class ContractionInstanceWrapper ...@@ -38,7 +38,7 @@ class ContractionInstanceWrapper
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
// clang-format on // clang-format on
bool isSupported(std::vector<ck::index_t>& ADims, bool isSupported(std::vector<ck::index_t>& ADims,
......
...@@ -87,6 +87,9 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) ...@@ -87,6 +87,9 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>(); this->template Run<2>();
} }
...@@ -99,5 +102,11 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) ...@@ -99,5 +102,11 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>(); this->template Run<3>();
} }
...@@ -147,14 +147,14 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -147,14 +147,14 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
32, 32,
4, 4,
2, 2,
S<1, 4, 32, 1>, S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder, ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value, ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value, ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value, ABlockLdsAddExtraM::value,
S<1, 4, 32, 1>, S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder, BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value, BBlockTransferSrcVectorDim::value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment