Unverified Commit b9b9c3b8 authored by Shaojie WANG, committed by GitHub


[Perf][Bwd-weights] LDS re-layout to avoid ds read/write bank conflicts and balance ds ops with address calculations (#190)

* add some instances to develop

* avoid bank conflicts for wrw for all instances

* add small K1 test

* delete some unused instances

* reset buffer load oob and ds memcpy to default option

* remove useless instances

* remove redundant space

* remove printf code

* clang-format-10 change

* fix clang format for the other files

* add bank length computation

* add template parameter to distinguish the instances that need lds padding for wrw

* use rocm5.1 as docker

* use integer value for GEMM test

* 1. move dedicated transform into gridwisegemm's header file. 2. make lds tensor params a struct template. 3. remove useless code

* use a new gridwise gemm header for bwd-weight

* revert gridwise gemm v2r4r2

* change format

* rename kernel invoker
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent bb4b82a9
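For readers skimming the diff below: LDS on these GPUs is split into 32 banks of 4 bytes each, and lanes whose addresses fall into the same bank in the same cycle serialize their accesses — that is the ds read/write conflict the title refers to. The usual remedy, and the one the new ABlockLdsM1Padding/BBlockLdsN1Padding constants appear to implement a variant of, is to pad the LDS tile so that logically adjacent rows start in different banks. The snippet below is a minimal, self-contained illustration of that bank arithmetic with made-up tile sizes; it is not the kernel's actual layout code.

```cpp
// Illustrative only: shows how a per-row pad changes which LDS bank each
// element of a strided access lands in. All tile sizes here are made up.
#include <cstdio>

constexpr int NumBanks  = 32; // LDS has 32 banks
constexpr int BankBytes = 4;  // each bank serves 4 bytes per cycle
constexpr int ElemBytes = 4;  // e.g. fp32 elements
constexpr int Rows      = 8;
constexpr int Cols      = 32;

// Bank hit by element (row, col) of a row-major [Rows][Cols + pad] tile.
constexpr int bank_of(int row, int col, int pad)
{
    const int elem_offset = row * (Cols + pad) + col;
    return (elem_offset * ElemBytes / BankBytes) % NumBanks;
}

int main()
{
    // Reading one column (fixed col, all rows):
    //   pad = 0 -> every row maps to the same bank, so the 8 lanes serialize;
    //   pad = 1 -> each row shifts by one bank, so the 8 lanes spread out.
    const int pads[] = {0, 1};
    for (int pad : pads)
    {
        std::printf("pad = %d, banks hit:", pad);
        for (int row = 0; row < Rows; ++row)
            std::printf(" %2d", bank_of(row, /*col=*/0, pad));
        std::printf("\n");
    }
    return 0;
}
```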
@@ -11,7 +11,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
-#include "gridwise_gemm_xdlops_v2r4r2.hpp"
+#include "gridwise_gemm_xdlops_bwd_weight.hpp"
namespace ck {
namespace tensor_operation {
@@ -81,6 +81,20 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
+// Bytes per 32 lds bank: 32 * 4 bytes
+static constexpr auto BankLength = 128;
+static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
+// M1 & M0
+static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
+static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
+static constexpr auto ABlockLdsM1Padding = 4;
+// N1 & N0
+static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
+static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
+static constexpr auto BBlockLdsN1Padding = 4;
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
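To make the new constants in the hunk above concrete, here is the same arithmetic evaluated with assumed values (2-byte elements such as fp16, K1 = 8, MPerBlock = 128); these numbers are illustrative and not taken from any particular instance in the diff. The 128-byte BankLength is one full sweep of the 32 four-byte LDS banks, so ElePerBank is how many elements cover all banks once, and M1PerBlock is how many K1-wide rows fit into that sweep.

```cpp
// Worked example of the BankLength arithmetic with assumed parameters
// (K1 = 8, MPerBlock = 128, 2-byte data). Values are illustrative only.
#include <cstdio>

using ADataType = short; // stand-in for a 2-byte element type such as fp16

constexpr int K1        = 8;
constexpr int MPerBlock = 128;

constexpr int BankLength = 128;                            // 32 banks * 4 bytes
constexpr int ElePerBank = BankLength / sizeof(ADataType); // 64 elements per bank sweep
constexpr int ABlockLdsM1PerBlock = ElePerBank / K1;       // 8 rows of K1 elements per sweep
constexpr int ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; // 16 such groups per block
constexpr int ABlockLdsM1Padding  = 4;                     // fixed LDS pad value used by the diff above

int main()
{
    std::printf("ElePerBank = %d, M1PerBlock = %d, M0PerBlock = %d, M1Padding = %d\n",
                ElePerBank, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding);
    return 0;
}
```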
@@ -205,7 +219,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm
-using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
+using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
@@ -233,6 +247,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
+ABlockLdsM1PerBlock,
+ABlockLdsM0PerBlock,
+ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
@@ -241,12 +258,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
+BBlockLdsN1PerBlock,
+BBlockLdsN0PerBlock,
+BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
-CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
+CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+true,
+true>;
-using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
+using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
@@ -274,6 +296,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
+ABlockLdsM1PerBlock,
+ABlockLdsM0PerBlock,
+ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
@@ -282,10 +307,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
+BBlockLdsN1PerBlock,
+BBlockLdsN0PerBlock,
+BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
-CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
+CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+true,
+true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
@@ -465,7 +495,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
if(kbatch == 1)
{
-const auto kernel = kernel_gemm_xdlops_v2r4r2<
+const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
@@ -482,7 +512,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
else
{
-const auto kernel = kernel_gemm_xdlops_v2r4r2<
+const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
@@ -502,7 +532,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
if(kbatch == 1)
{
-const auto kernel = kernel_gemm_xdlops_v2r4r2<
+const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
@@ -519,7 +549,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
else
{
-const auto kernel = kernel_gemm_xdlops_v2r4r2<
+const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......
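The remaining hunks only swap the kernel entry point from kernel_gemm_xdlops_v2r4r2 to kernel_gemm_xdlops_bwd_weight; the surrounding dispatch is the usual split-K pattern: with kbatch == 1 the plain GridwiseGemm stores its result directly, while with kbatch > 1 several partial GEMMs over slices of K target the same output buffer, which is presumably why the GridwiseGemmAtomicAdd variant is selected on that path. Below is a minimal host-side sketch of that idea using plain C++ threads and assumed shapes; it is not the device code, and the real kernels accumulate with GPU atomics.

```cpp
// Minimal split-K sketch (assumed shapes, plain C++ rather than the real
// kernels): each "k batch" computes a partial GEMM over its slice of K and
// accumulates into the shared C buffer. With kbatch > 1 the partial results
// arrive concurrently, which is why that path needs atomic adds; with
// kbatch == 1 a plain store is enough.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

constexpr int M = 4, N = 4, K = 8, KBatch = 2;

int main()
{
    std::vector<int> a(M * K, 1), b(K * N, 1);      // all-ones inputs
    std::vector<std::atomic<int>> c(M * N);
    for (auto& x : c) x.store(0);

    auto partial_gemm = [&](int kb) {
        const int k_begin = kb * (K / KBatch), k_end = k_begin + K / KBatch;
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                int acc = 0;
                for (int k = k_begin; k < k_end; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                c[m * N + n].fetch_add(acc); // atomic: other k batches write here too
            }
    };

    std::vector<std::thread> workers;
    for (int kb = 0; kb < KBatch; ++kb) workers.emplace_back(partial_gemm, kb);
    for (auto& t : workers) t.join();

    std::printf("c[0] = %d (expected %d)\n", c[0].load(), K);
    return 0;
}
```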