Commit 0fb89c4a authored by Rosty Geyyer
Browse files

Update MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N

parent 6f0b21a7
......@@ -142,178 +142,317 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const auto K0 = K / K1;
// const auto K0 = K / K1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
// const auto wei_n_x_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, X, C));
const index_t GemmKTotal = N * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) *
K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
make_tuple(make_pass_through_transform(N * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
// make_tuple(make_pass_through_transform(N * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
// const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
// const auto in_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// const auto wei_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// wei_n_x_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// wei_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
else
{
const auto out_n_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K));
const auto wei_k_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X, C));
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
// A: output tensor
const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
out_n_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
out_n_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_merge_transform(make_tuple(N, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
// B: input tensor
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X, C)),
make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
// out_n_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
// wei_k_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
} // function end
......@@ -357,185 +496,126 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const auto K0 = K / K1;
// const auto K0 = K / K1;
// const auto out_n_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
// const auto wei_k_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
// const auto in_n_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
const auto wei_k_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) *
K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
// const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
......@@ -545,48 +625,222 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5, 6>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_y_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<4>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
} // function end
......@@ -639,241 +893,379 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const auto K0 = K / K1;
// const auto K0 = K / K1;
// const auto out_n_do_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
// const auto wei_k_z_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
// const auto in_n_di_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
const auto wei_k_z_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) *
K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
// const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_z_do_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Do, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<0, 2, 4, 6>{},
// Sequence<7>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<0, 2, 4, 6>{},
Sequence<7>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on the DTilde, HTilde and WTilde ranges that contribute to the non-padding area of the input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto ZTilde = ConvStrideD / GcdStrideDilationD;
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto DTilde =
// Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on the DTilde, HTilde and WTilde ranges that contribute to the non-padding area of the input tensor
// const auto IDTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IDTildeSliceEnd = math::min(
// DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_do_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Do, I0, I0),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
// transform_tensor_descriptor(
// out_n_dop_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZDot, DTilde),
// make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7, 8>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Di, InLeftPadD, InRightPadD),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
// transform_tensor_descriptor(
// in_n_dip_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZTilde, DTilde),
// make_tuple(ConvDilationD, ConvStrideD)),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
// transform_tensor_descriptor(
// in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ztilde),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<>{},
// Sequence<3>{},
// Sequence<4>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
// transform_tensor_descriptor(
// wei_k_z_y_x_c_grid_desc,
// make_tuple(
// make_pass_through_transform(K),
// make_embed_transform(make_tuple(ZDot, ZTilde),
// make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ztilde),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<5>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
// A: output tensor
const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_dop_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8>{}));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B: weight tensor
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_k_z_y_x_c_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2>{},
Sequence<4>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{},
Sequence<>{},
Sequence<>{},
Sequence<5>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
......@@ -886,64 +1278,50 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // function end
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment