Unverified Commit 070619fb authored by Shaojie WANG's avatar Shaojie WANG Committed by GitHub
Browse files

[conv bwd-weight]Binding gemm k1 to conv n (#202)



* add some instance to develop

* avoid bank conflicts for wrw for all instance

* add small K1 test

* delete some unused instance

* binding gemm k1 to conv n

* try using half_4 to do ds_read

* reset buffer load oob and ds memcpy to default option

* remove useless instances

* remove redandunt space

* remove printf code

* clang-format-10 change

* use fastest config

* fix clang format for the other files

* remove gemmk0 pad for output

* add gemmk padding macro

* add bank length computation

* add template to distinguish the instance that need lds padding for wrw

* use rocm5.1 as docker

* use integer value for GEMM test

* add Right padding macro

* add 2 test asm code

* using 256x256x32 tile size

* 1. move dedicated transform into gridwisegemm's head file. 2. make lds tensor params a struct templete. 3. remove useless code

* using small vec

* 256*128 kernel size for example

* remove asm files

* use a new gridwise gemm header for bwd-weight

* revert gridwise gemm v2r4r2

* change foramt

* reset gridwise gemm v2r4r2

* remove unused code

* revert instance file

* revert example instance

* format file

* remove macros

* resolve compile error

* rename wrw kernel invoker

* use gridwisegemm pipeline struct instead of implement run fucntion in the same header
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent b31b588d
...@@ -81,6 +81,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -81,6 +81,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number; static constexpr auto GemmK1Number = K1Number;
static constexpr auto N1Number = K1Number;
// Bytes per 32 lds bank: 32 * 4 bytes // Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = 128; static constexpr auto BankLength = 128;
static constexpr auto ElePerBank = BankLength / sizeof(ADataType); static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
...@@ -139,27 +141,51 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -139,27 +141,51 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const index_t N0 = N / N1Number;
out_gemmktotal_gemmm_grid_desc, const index_t GemmK0Total = N0 * Ho * Wo;
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)), const index_t GemmK0S =
make_tuple(Sequence<0>{}, Sequence<1>{}), math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock;
make_tuple(Sequence<0>{}, Sequence<1>{})); const index_t GemmK0Pad = GemmKBatch * GemmK0S;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
const auto out_n0_ho_wo_k_n1_grid_desc =
transform_tensor_descriptor(out_n_ho_wo_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmk0total_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)),
make_pass_through_transform(K),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmk0pad_gemmm_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
// B: input tensor // B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
...@@ -181,26 +207,50 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -181,26 +207,50 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmn_grid_desc = const auto in_n0_y_ho_x_wo_c_n1_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_pass_through_transform(Y),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_pass_through_transform(Ho),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_pass_through_transform(X),
make_pass_through_transform(Wo),
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( make_pass_through_transform(C)),
in_gemmktotal_gemmn_grid_desc, make_tuple(Sequence<0>{},
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), Sequence<1>{},
make_pass_through_transform(GemmN)), Sequence<2>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<3>{},
make_tuple(Sequence<0>{}, Sequence<1>{})); Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 6>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_n0_y_ho_x_wo_c_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C)),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0total_gemmn_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
make_pass_through_transform(GemmN),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmk0pad_gemmn_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc =
...@@ -456,7 +506,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -456,7 +506,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.N01_)) arg.N01_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting");
} }
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
...@@ -474,21 +524,22 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -474,21 +524,22 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType))); sizeof(CDataType)));
launch_and_time_kernel(stream_config, ave_time =
kernel, launch_and_time_kernel(stream_config,
dim3(grid_size), kernel,
dim3(BlockSize), dim3(grid_size),
0, dim3(BlockSize),
arg.p_a_grid_, 0,
arg.p_b_grid_, arg.p_a_grid_,
arg.p_c_grid_, arg.p_b_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_, arg.p_c_grid_,
arg.b_grid_desc_kbatch_k0_n_k1_, arg.a_grid_desc_kbatch_k0_m_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_element_op_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.b_element_op_, arg.a_element_op_,
arg.c_element_op_, arg.b_element_op_,
arg.block_2_ctile_map_); arg.c_element_op_,
arg.block_2_ctile_map_);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
...@@ -592,6 +643,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -592,6 +643,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return false; return false;
} }
// unmerge N to N0 and N1, where N1 equals to K1
if(!(arg.Conv_N_ % K1 == 0))
{
return false;
}
// vector store C matrix into global memory // vector store C matrix into global memory
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
{ {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck { namespace ck {
...@@ -235,8 +236,9 @@ template <index_t BlockSize, ...@@ -235,8 +236,9 @@ template <index_t BlockSize,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
bool ABlockLdsExtraM1Wrw = false, bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false> bool BBlockLdsExtraN1Wrw = false,
index_t NumGemmKPrefetchStage = 1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -251,7 +253,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -251,7 +253,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
// M0/M1/M1Padding // M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{}; static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
...@@ -511,6 +514,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -511,6 +514,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
...@@ -548,9 +559,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -548,9 +559,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{ {
const bool has_main_k0_block_loop = K0 > K0PerBlock; // const bool has_main_k0_block_loop = K0 > K0PerBlock;
const index_t num_loop = K0 / K0PerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
return has_main_k0_block_loop; // return has_main_k0_block_loop;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -771,51 +785,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -771,51 +785,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS // gridwise GEMM pipeline
{ const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
a_b_k0_m_k1_block_desc,
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy,
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); a_grid_buf,
} a_block_buf,
a_block_slice_copy_step,
// Initialize C b_b_k0_n_k1_grid_desc,
c_thread_buf.Clear(); b_b_k0_n_k1_block_desc,
b_blockwise_copy,
// main body b_grid_buf,
if constexpr(HasMainKBlockLoop) b_block_buf,
{ b_block_slice_copy_step,
index_t k0_block_data_begin = 0; blockwise_gemm,
c_thread_buf,
do K0BlockMainLoop);
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// output: register to global memory // output: register to global memory
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment