Commit 37116c98 authored by Rosty Geyyer
Browse files

Refactor argument preparation

parent 0fb89c4a
......@@ -128,12 +128,10 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
ck::index_t batch_k)
{
using namespace ck;
index_t i_xtilde = tildes[0];
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0];
......@@ -142,30 +140,20 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
// const auto K0 = K / K1;
// const auto wei_n_x_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, X, C));
const index_t GemmKTotal = N * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) *
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
// make_tuple(make_pass_through_transform(N * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
......@@ -184,32 +172,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
// const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
// const auto in_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
......@@ -228,37 +190,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// const auto wei_n_x_wo_c_grid_desc = transform_tensor_descriptor(
// wei_n_x_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// wei_n_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
......@@ -331,128 +265,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
// out_n_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
// wei_k_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
} // function end
......@@ -468,13 +280,10 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
ck::index_t batch_k)
{
using namespace ck;
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
......@@ -496,35 +305,20 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
// const auto K0 = K / K1;
// const auto out_n_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
// const auto wei_k_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
// const auto in_n_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) *
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
......@@ -543,24 +337,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
// const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
......@@ -579,13 +355,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
......@@ -663,184 +432,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
// out_n_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5, 6>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
// in_n_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
// in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<3>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
// wei_k_y_x_c_grid_desc,
// make_tuple(make_pass_through_transform(K),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<4>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
}
} // function end
......@@ -857,14 +448,10 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
ck::index_t batch_k)
{
using namespace ck;
const index_t i_ztilde = tildes[0];
const index_t i_ytilde = tildes[1];
const index_t i_xtilde = tildes[2];
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
......@@ -893,35 +480,20 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
// const auto K0 = K / K1;
// const auto out_n_do_ho_wo_k_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
// const auto wei_k_z_y_x_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
// const auto in_n_di_hi_wi_c_grid_desc =
// make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) *
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
// make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
......@@ -940,35 +512,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
// const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
// make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
// make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_z_do_y_ho_x_wo_c_grid_desc,
// make_tuple(make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_freeze_transform(I0),
// make_merge_transform(make_tuple(N, Do, Ho, Wo)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<0, 2, 4, 6>{},
// Sequence<7>{}),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
......@@ -987,13 +530,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weights tensor
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
// transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
......@@ -1003,248 +539,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
else
{
// const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
// const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
// const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// const auto ZTilde = ConvStrideD / GcdStrideDilationD;
// const auto YTilde = ConvStrideH / GcdStrideDilationH;
// const auto XTilde = ConvStrideW / GcdStrideDilationW;
// const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
// const auto YDot = math::integer_divide_ceil(Y, YTilde);
// const auto XDot = math::integer_divide_ceil(X, XTilde);
// const auto DTilde =
// Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
// const auto HTilde =
// Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
// const auto WTilde =
// Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// // only work on HTilde and WTilde that contribute to non-padding area of input tensor
// const auto IDTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
// const auto IHTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
// const auto IWTildeSliceBegin = math::integer_divide_floor(
// math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// const auto IDTildeSliceEnd = math::min(
// DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
// const auto IHTildeSliceEnd = math::min(
// HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
// const auto IWTildeSliceEnd = math::min(
// WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
// const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
// const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
// const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// // GemmK is different for each GEMM
// const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
// const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
// const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// // A: output tensor
// const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
// out_n_do_ho_wo_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Do, I0, I0),
// make_pad_transform(Ho, I0, I0),
// make_pad_transform(Wo, I0, I0),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
// transform_tensor_descriptor(
// out_n_dop_hop_wop_k_grid_desc,
// make_tuple(
// make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZDot, DTilde),
// make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, HTilde),
// make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, WTilde),
// make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
// make_pass_through_transform(K)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
// transform_tensor_descriptor(
// out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_unmerge_transform(make_tuple(K0, K1))),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7, 8>{}));
// const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
// out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// // B: input tensor
// const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
// in_n_di_hi_wi_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_pad_transform(Di, InLeftPadD, InRightPadD),
// make_pad_transform(Hi, InLeftPadH, InRightPadH),
// make_pad_transform(Wi, InLeftPadW, InRightPadW),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
// transform_tensor_descriptor(
// in_n_dip_hip_wip_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_embed_transform(make_tuple(ZTilde, DTilde),
// make_tuple(ConvDilationD, ConvStrideD)),
// make_embed_transform(make_tuple(YTilde, HTilde),
// make_tuple(ConvDilationH, ConvStrideH)),
// make_embed_transform(make_tuple(XTilde, WTilde),
// make_tuple(ConvDilationW, ConvStrideW)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
// transform_tensor_descriptor(
// in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
// make_tuple(make_pass_through_transform(N),
// make_freeze_transform(i_ztilde),
// make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
// make_freeze_transform(i_ytilde),
// make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
// make_freeze_transform(i_xtilde),
// make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<5>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0>{},
// Sequence<>{},
// Sequence<1>{},
// Sequence<>{},
// Sequence<2>{},
// Sequence<>{},
// Sequence<3>{},
// Sequence<4>{}));
// const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
// in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
// make_tuple(
// make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// // C: weights tensor
// const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
// transform_tensor_descriptor(
// wei_k_z_y_x_c_grid_desc,
// make_tuple(
// make_pass_through_transform(K),
// make_embed_transform(make_tuple(ZDot, ZTilde),
// make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
// make_embed_transform(make_tuple(YDot, YTilde),
// make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
// make_embed_transform(make_tuple(XDot, XTilde),
// make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
// make_pass_through_transform(C)),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
// make_tuple(Sequence<0>{},
// Sequence<1, 2>{},
// Sequence<3, 4>{},
// Sequence<5, 6>{},
// Sequence<7>{}));
// const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
// transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
// make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
// make_slice_transform(ZDot, I0, ZDotSlice),
// make_slice_transform(YDot, I0, YDotSlice),
// make_slice_transform(XDot, I0, XDotSlice),
// make_freeze_transform(i_ztilde),
// make_freeze_transform(i_ytilde),
// make_freeze_transform(i_xtilde),
// make_pass_through_transform(C)),
// make_tuple(Sequence<0>{},
// Sequence<1>{},
// Sequence<3>{},
// Sequence<5>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<6>{},
// Sequence<7>{}),
// make_tuple(Sequence<0, 1>{},
// Sequence<2>{},
// Sequence<3>{},
// Sequence<4>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<>{},
// Sequence<5>{}));
// const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
// wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
// make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
// make_pass_through_transform(C),
// make_pass_through_transform(K1)),
// make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
// in_gemmm_gemmn_grid_desc,
// wei_gemmk0_gemmn_gemmk1_grid_desc);
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto in_n_di_hi_wi_c_grid_desc =
......@@ -1330,14 +624,14 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
}
/// Builds the A/B/C grid descriptors for the 2-D spatial case from placeholder sizes.
///
/// Arguments follow the order used by the other callers of
/// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N: N, K, C, then the input / filter /
/// output spatial lengths, filter strides, filter dilations, left pads, right pads
/// (one entry per spatial dimension, all set to 1 here), and finally the scalar
/// batch-K value.
///
/// The last argument is a plain index (1), not a per-dimension `{0, 0}` list: the
/// descriptor factory takes `ck::index_t batch_k` in that position.
/// NOTE(review): only the 1-D factory signature is visible here; the 2-D overload is
/// assumed to take the same trailing `batch_k` parameter -- confirm.
///
/// @return whatever MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2> returns; the
///         `ABCGridDescs` alias consumes this through `decltype`.
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
        1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
......@@ -1353,7 +647,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{0, 0, 0});
1);
}
// Type of the descriptor bundle returned by GetABCGridDesc for this operator's
// NDimSpatial; deduced with decltype, so GetABCGridDesc is not evaluated here.
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
......@@ -1429,6 +723,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
......@@ -1443,208 +740,51 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
CreateABCDesc<NDimSpatial>();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideW = conv_filter_strides_[0];
const index_t ConvDilationW = conv_filter_dilations_[0];
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t X = filter_spatial_lengths_[0];
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(XDotSlice <= 0)
{
continue;
}
k_batch_ = 1;
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideH = conv_filter_strides_[0];
const index_t ConvStrideW = conv_filter_strides_[1];
const index_t ConvDilationH = conv_filter_dilations_[0];
const index_t ConvDilationW = conv_filter_dilations_[1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t Y = filter_spatial_lengths_[0];
const index_t X = filter_spatial_lengths_[1];
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideD = conv_filter_strides_[0];
const index_t ConvStrideH = conv_filter_strides_[1];
const index_t ConvStrideW = conv_filter_strides_[2];
const index_t ConvDilationD = conv_filter_dilations_[0];
const index_t ConvDilationH = conv_filter_dilations_[1];
const index_t ConvDilationW = conv_filter_dilations_[2];
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t Z = filter_spatial_lengths_[0];
const index_t Y = filter_spatial_lengths_[1];
const index_t X = filter_spatial_lengths_[2];
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(ZDotSlice * YDotSlice * XDotSlice <= 0)
{
continue;
}
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
}
a_grid_desc_k0_m0_m1_k1_ = GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ = GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<AGridDesc_K0_M0_M1_K1> a_grid_desc_k0_m0_m1_k1_container_;
std::vector<BGridDesc_K0_N0_N1_K1> b_grid_desc_k0_n0_n1_k1_container_;
std::vector<CGridDesc_M0_M10_M11_N0_N10_N11> c_grid_desc_m0_m10_m11_n0_n10_n11_container_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
std::vector<DefaultBlock2CTileMap> block_2_ctile_map_container_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
// element-wise op
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
......@@ -1657,6 +797,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
index_t k_batch_;
};
// Invoker
......@@ -1664,53 +805,41 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
void ShowInfo(const Argument& arg)
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
std::cout << "arg.c_grid_desc_m_n_{ "
<< arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
......@@ -1728,8 +857,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
has_main_loop,
has_double_loop>;
ave_time +=
launch_and_time_kernel(stream_config,
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
......@@ -1737,39 +865,37 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_container_[i],
arg.b_grid_desc_k0_n0_n1_k1_container_[i],
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i],
arg.block_2_ctile_map_container_[i]);
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0);
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
return launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, true>{},
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, false>{},
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
launch_kernel(integral_constant<bool, false>{},
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
......@@ -1806,26 +932,6 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
}
// // vector load A/B matrix from global memory
// if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
// arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
// arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
// {
// return false;
// }
// // vector store C matrix into global memory
// if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
// {
// return false;
// }
// // Gridwise GEMM size
// return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
// arg.b_grid_desc_kbatch_k0_n_k1_,
// arg.c_grid_desc_m_n_,
// arg.block_2_ctile_map_);
// matrix A
{
auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
......@@ -1877,16 +983,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
{
return false;
}
}
return true;
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment