Unverified Commit 595d23be authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix transform and instances for grouped conv bwd data (#848)

* Fix transform and instances for grouped conv bwd data

* Add instances for small K and small C

* Remove workaround after fix

* Fix interface tests
parent eac50708
......@@ -200,9 +200,6 @@
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_GITHUB_ISSUE_824 1
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
......
......@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
......@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
......
......@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
......
......@@ -236,8 +236,6 @@ struct TransformConvBwdDataToGemm_v1
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const index_t AK0 = K / AK1;
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc =
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
......@@ -247,6 +245,8 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t AK0 = math::integer_divide_ceil(K, AK1);
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
......@@ -308,6 +308,9 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t AK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
......@@ -332,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
......@@ -340,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -352,21 +355,28 @@ struct TransformConvBwdDataToGemm_v1
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
Sequence<5>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(AK1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
......@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
......@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -437,22 +447,29 @@ struct TransformConvBwdDataToGemm_v1
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8>{}));
Sequence<7>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, AK0)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(AK1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
......@@ -505,8 +522,6 @@ struct TransformConvBwdDataToGemm_v1
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const index_t BK0 = K / BK1;
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
......@@ -515,6 +530,8 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t BK0 = math::integer_divide_ceil(K, BK1);
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
......@@ -551,6 +568,9 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t BK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
// B weight tensor
if constexpr(NDimSpatial == 2)
{
......@@ -566,43 +586,47 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<>{},
Sequence<>{},
Sequence<3>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
make_pass_through_transform(C),
make_pass_through_transform(BK1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
GemmNPerBlock,
BK1),
Sequence<false, DoPadGemmN, false>{});
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
......@@ -631,10 +655,10 @@ struct TransformConvBwdDataToGemm_v1
Sequence<5, 6>{},
Sequence<7>{}));
const auto wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc =
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_tuple(make_pass_through_transform(K),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
......@@ -650,33 +674,37 @@ struct TransformConvBwdDataToGemm_v1
Sequence<4>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0, 1>{},
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{},
Sequence<>{},
Sequence<>{},
Sequence<5>{}));
Sequence<4>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, BK0)),
make_pass_through_transform(C),
make_pass_through_transform(BK1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
const auto wei_gemmk_gemm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
GemmNPerBlock,
BK1),
Sequence<false, DoPadGemmN, false>{});
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemm_gemmbk1_grid_desc;
}
else
{
......
......@@ -47,24 +47,26 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
#ifndef CK_WORKAROUND_GITHUB_ISSUE_824
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// TODO: After enable, add instance for small conv.K and conv.C
#endif
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
......@@ -75,33 +77,34 @@ template <index_t NDimSpatial,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bf16_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
#ifndef CK_WORKAROUND_GITHUB_ISSUE_824
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// TODO: After enable, add instance for small conv.K and conv.C
#endif
// clang-format on
>;
using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
// f32_f32_f32_f32
template <index_t NDimSpatial,
......@@ -117,24 +120,26 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
#ifndef CK_WORKAROUND_GITHUB_ISSUE_824
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// TODO: After enable, add instance for small conv.K and conv.C
#endif
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
......
......@@ -38,7 +38,7 @@ class ContractionInstanceWrapper
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
// clang-format on
bool isSupported(std::vector<ck::index_t>& ADims,
......
......@@ -87,6 +87,9 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
......@@ -99,5 +102,11 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
......@@ -147,14 +147,14 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
32,
4,
2,
S<1, 4, 32, 1>,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 32, 1>,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment