"vscode:/vscode.git/clone" did not exist on "ed305f6b5cc6867a64e7627f74cab8da2c2f5613"
Commit add7421f authored by wangshaojie6's avatar wangshaojie6
Browse files

remove gemmk0 pad for output

parent e6b32ffe
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "gridwise_gemm_xdlops_v2r4r2.hpp" #include "gridwise_gemm_xdlops_v2r4r2.hpp"
#define SPLITN0_N1 1 #define SPLITN0_N1 1
#define GEMMK0PAD_FOR_OUT 0
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -162,27 +163,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -162,27 +163,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const index_t GemmK0Pad = GemmKBatch * GemmK0S; const index_t GemmK0Pad = GemmKBatch * GemmK0S;
const auto out_n_ho_wo_k_grid_desc = const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
const auto out_n0_ho_wo_k_n1_grid_desc = transform_tensor_descriptor( const auto out_n0_ho_wo_k_n1_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc, out_n_ho_wo_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Ho), make_pass_through_transform(Ho * Wo),
make_pass_through_transform(Wo),
make_pass_through_transform(K)), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}) make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})
); );
const auto out_gemmk0total_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0total_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n0_ho_wo_k_n1_grid_desc, out_n0_ho_wo_k_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)), make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)),
make_pass_through_transform(K), make_pass_through_transform(K),
make_pass_through_transform(N1Number)), make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})
); );
#if GEMMK0PAD_FOR_OUT
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc, out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
...@@ -198,6 +199,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -198,6 +199,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_pass_through_transform(N1Number)), make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#else
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#endif
#endif #endif
// B: input tensor // B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment