Commit 798e93e1 authored by wangshaojie6
Browse files

Add GEMMK0PAD_FOR_IN macro for GemmK0 padding and enable GEMMK0PAD_FOR_OUT

parent add7421f
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include "gridwise_gemm_xdlops_v2r4r2.hpp" #include "gridwise_gemm_xdlops_v2r4r2.hpp"
#define SPLITN0_N1 1 #define SPLITN0_N1 1
#define GEMMK0PAD_FOR_OUT 0 #define GEMMK0PAD_FOR_OUT 1
#define GEMMK0PAD_FOR_IN 1
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -156,12 +157,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -156,12 +157,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
#else #else
const index_t N0 = N / N1Number; const index_t N0 = N / N1Number;
#if (GEMMK0PAD_FOR_OUT|GEMMK0PAD_FOR_IN)
const index_t GemmK0Total = N0 * Ho * Wo; const index_t GemmK0Total = N0 * Ho * Wo;
const index_t GemmK0S = math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * const index_t GemmK0S = math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmK0Pad = GemmKBatch * GemmK0S; const index_t GemmK0Pad = GemmKBatch * GemmK0S;
#endif
const auto out_n_ho_wo_k_grid_desc = const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
...@@ -274,6 +276,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -274,6 +276,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})
); );
#if GEMMK0PAD_FOR_IN
const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0total_gemmn_gemmk1_grid_desc, in_gemmk0total_gemmn_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
...@@ -289,6 +292,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -289,6 +292,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_pass_through_transform(N1Number)), make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#else
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0total_gemmn_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmN),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
#endif
#endif #endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment