Commit add7421f authored by wangshaojie6's avatar wangshaojie6
Browse files

remove gemmk0 pad for output

parent e6b32ffe
......@@ -14,6 +14,7 @@
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#define SPLITN0_N1 1
#define GEMMK0PAD_FOR_OUT 0
namespace ck {
namespace tensor_operation {
......@@ -162,27 +163,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const index_t GemmK0Pad = GemmKBatch * GemmK0S;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
const auto out_n0_ho_wo_k_n1_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo),
make_pass_through_transform(Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})
);
const auto out_gemmk0total_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n0_ho_wo_k_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)),
make_pass_through_transform(K),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})
);
#if GEMMK0PAD_FOR_OUT
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
......@@ -198,6 +199,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#else
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#endif
#endif
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment