Commit 45440f17 authored by wangshaojie6's avatar wangshaojie6
Browse files

binding gemm k1 to conv n

parent 0fd4df3b
......@@ -44,20 +44,20 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
256, // BlockSize
128, // MPerBlock
256, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 4, 32, 2>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
4, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
......
......@@ -312,6 +312,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
//static constexpr index_t A_K1_vec = A_K1;// / 2;
//static constexpr index_t B_K1_vec = B_K1;// / 2;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
......
......@@ -13,6 +13,8 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#define SPLITN0_N1 1
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -81,6 +83,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
static constexpr auto N1Number = K1Number;
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
......@@ -125,14 +129,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
#if !SPLITN0_N1
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
#endif
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
// A: output tensor
#if !SPLITN0_N1
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
......@@ -146,7 +153,52 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
#else
const index_t N0 = N / N1Number;
const index_t GemmK0Total = N0 * Ho * Wo;
const index_t GemmK0S = math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmK0Pad = GemmKBatch * GemmK0S;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
const auto out_n0_ho_wo_k_n1_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})
);
const auto out_gemmk0total_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n0_ho_wo_k_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)),
make_pass_through_transform(K),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})
);
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0pad_gemmm_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#endif
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
......@@ -167,6 +219,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
#if !SPLITN0_N1
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
......@@ -187,6 +240,47 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
#else
const auto in_n0_y_ho_x_wo_c_n1_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Y),
make_pass_through_transform(Ho),
make_pass_through_transform(X),
make_pass_through_transform(Wo),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(Sequence<0, 6>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_n0_y_ho_x_wo_c_n1_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N0, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C)),
make_pass_through_transform(N1Number)
),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})
);
const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0total_gemmn_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
make_pass_through_transform(GemmN),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0pad_gemmn_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmN),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#endif
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
......
......@@ -43,7 +43,7 @@ using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple<
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment