"src/regex_yaml.cpp" did not exist on "620c58abec1172971b98ddd1a1f75e62088dfee1"
Commit 1e5c712b authored by wangshaojie6

add template parameters to distinguish the instances that need LDS padding for wrw

parent 93871ca1
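
The change drops the global A_BLOCK_BANK_CONFLICT_FREE_WRW / B_BLOCK_BANK_CONFLICT_FREE_WRW macros and instead threads two boolean template parameters, ABlockLdsExtraM1Wrw and BBlockLdsExtraN1Wrw (both defaulted to false), through the gridwise GEMM, so only the WRW (backward-weight) instances that need the bank-conflict-free LDS layout pay for the extra padding. The sketch below shows the general pattern under simplified, assumed names (PaddedDesc, FlatDesc and a_block_lds_size are illustrative only, not repository types): a defaulted bool template parameter selects between a padded and an unpadded LDS layout at compile time via if constexpr.

#include <cstdio>

// Hypothetical stand-ins for the padded / unpadded LDS block layouts;
// these names do not exist in composable_kernel.
template <int MPerBlock, int K1, int Pad>
struct PaddedDesc
{
    // extra elements per row so successive rows land on different LDS banks
    static constexpr int row_stride = K1 + Pad;
    static constexpr int size       = MPerBlock * row_stride;
};

template <int MPerBlock, int K1>
struct FlatDesc
{
    // no padding: smaller footprint, but rows may alias to the same bank
    static constexpr int row_stride = K1;
    static constexpr int size       = MPerBlock * row_stride;
};

// Same selection pattern the commit introduces: a bool template parameter,
// defaulted to false, picks the layout at compile time via if constexpr.
template <int MPerBlock, int K1, bool ABlockLdsExtraM1Wrw = false>
constexpr int a_block_lds_size()
{
    if constexpr(ABlockLdsExtraM1Wrw)
    {
        return PaddedDesc<MPerBlock, K1, 4>::size;
    }
    else
    {
        return FlatDesc<MPerBlock, K1>::size;
    }
}

int main()
{
    std::printf("default (no wrw padding): %d elements\n", a_block_lds_size<128, 8>());
    std::printf("wrw padding enabled     : %d elements\n", a_block_lds_size<128, 8, true>());
    return 0;
}

The device instance touched in this diff passes true for both flags, so its GridwiseGemm and GridwiseGemmAtomicAdd get the padded A/B block descriptors; any other device instance keeps the defaults and the previous layout.
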
@@ -244,7 +244,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CBlockTransferScalarPerVector_NWaveNPerXdl,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        true,
+        true>;
 
     using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
@@ -285,7 +287,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CBlockTransferScalarPerVector_NWaveNPerXdl,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        true,
+        true>;
 
     // Argument
     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
         decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
@@ -11,9 +11,6 @@
 #include "blockwise_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 
-#define A_BLOCK_BANK_CONFLICT_FREE_WRW 1
-#define B_BLOCK_BANK_CONFLICT_FREE_WRW 1
-
 namespace ck {
 
 template <typename GridwiseGemm,
@@ -112,7 +109,9 @@ template <index_t BlockSize,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
           index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
-          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
+          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          bool ABlockLdsExtraM1Wrw = false,
+          bool BBlockLdsExtraN1Wrw = false>
 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 {
     static constexpr auto I0 = Number<0>{};
@@ -126,21 +125,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     // Bytes per 32 lds bank: 32 * 4 bytes
     static constexpr auto BankLength = Number<128>{};
-    static constexpr auto ElePerBank = Number<BankLength / sizeof(FloatAB)>{};
 
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
 
-    // M1 & N1
+    // M1 & M0
+    static constexpr auto ElePerBank = Number<BankLength / sizeof(FloatAB)>{};
     static constexpr auto M1PerBlock = Number<ElePerBank / K1Value>{};
-    static constexpr auto N1PerBlock = Number<ElePerBank / K1Value>{};
-
-    // M0 & N0
     static constexpr auto M0PerBlock = Number<MPerBlock / M1PerBlock>{};
-    static constexpr auto N0PerBlock = Number<NPerBlock / M1PerBlock>{};
+    static constexpr auto M1Padding = I4;
 
-    // M1 padding num
-    static constexpr auto M1Padding = Number<4>{};
-    static constexpr auto N1Padding = M1Padding;
+    // N1 & N0
+    static constexpr auto N1PerBlock = Number<ElePerBank / K1Value>{};
+    static constexpr auto N0PerBlock = Number<NPerBlock / M1PerBlock>{};
+    static constexpr auto N1Padding = I4;
 
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
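
For a sense of the numbers behind these constants (the values below are assumed for illustration, the diff itself does not fix them: FloatAB = fp16, so sizeof(FloatAB) = 2, and K1 = 8): ElePerBank = 128 / 2 = 64, so M1PerBlock = N1PerBlock = 64 / 8 = 8, and one m1-by-k1 slice spans exactly the 32 * 4-byte LDS banks. Without padding every m0 slice would then start at bank 0; assuming the 4-element M1Padding is applied once per m0 slice, as the padded descriptors suggest, each successive slice is staggered by two banks. A small check of that arithmetic:

#include <cstdio>

int main()
{
    // Assumed for illustration only: FloatAB = fp16 (2 bytes), K1 = 8.
    constexpr int bank_bytes   = 4;                        // one LDS bank entry
    constexpr int num_banks    = 32;
    constexpr int bank_length  = num_banks * bank_bytes;   // 128 bytes
    constexpr int elem_bytes   = 2;                        // fp16
    constexpr int ele_per_bank = bank_length / elem_bytes; // 64 elements
    constexpr int k1           = 8;
    constexpr int m1_per_block = ele_per_bank / k1;        // 8
    constexpr int m1_padding   = 4;

    constexpr int flat_slice   = m1_per_block * k1;              // 64 elements = 128 bytes
    constexpr int padded_slice = m1_per_block * k1 + m1_padding; // 68 elements = 136 bytes

    // Starting bank of each m0 slice: always 0 without padding,
    // staggered by 2 banks per slice with the 4-element pad.
    for(int m0 = 0; m0 < 4; ++m0)
    {
        const int flat_bank   = (m0 * flat_slice * elem_bytes / bank_bytes) % num_banks;
        const int padded_bank = (m0 * padded_slice * elem_bytes / bank_bytes) % num_banks;
        std::printf("m0 = %d: start bank flat = %2d, padded = %2d\n", m0, flat_bank, padded_bank);
    }
    return 0;
}
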
@@ -150,7 +148,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         constexpr auto a_block_desc_k0_m_k1 = [&]() {
             if constexpr(ABlockLdsExtraM)
             {
-#if A_BLOCK_BANK_CONFLICT_FREE_WRW
+                if constexpr(ABlockLdsExtraM1Wrw)
+                {
                 constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
                     make_tuple(
                         Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
@@ -169,11 +168,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
 
                 return a_block_desc_k0_m_k1_tmp;
-#else
+                }
+                else
+                {
                 return make_naive_tensor_descriptor(
                     make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-#endif
+                }
             }
             else
             {
@@ -193,7 +194,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         constexpr auto a_block_desc_b_k0_m_k1 = [&]() {
             if constexpr(ABlockLdsExtraM)
             {
-#if A_BLOCK_BANK_CONFLICT_FREE_WRW
+                if constexpr(ABlockLdsExtraM1Wrw)
+                {
                 constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
                     make_tuple(Number<1>{},
                                Number<K0PerBlock>{},
@@ -218,14 +220,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
                 return a_block_desc_b_k0_m_k1_tmp;
-#else
+                }
+                else
+                {
                 return make_naive_tensor_descriptor(
                     make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                                Number<MPerBlock + 1>{} * K1,
                                K1,
                                I1));
-#endif
+                }
             }
             else
             {
@@ -246,7 +250,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         constexpr auto b_block_desc_k0_n_k1 = [&]() {
             if constexpr(BBlockLdsExtraN)
             {
-#if B_BLOCK_BANK_CONFLICT_FREE_WRW
+                if constexpr(BBlockLdsExtraN1Wrw)
+                {
                 constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
                     make_tuple(
                         Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
@@ -265,12 +270,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
 
                 return b_block_desc_k0_n_k1_tmp;
-#else
+                }
+                else
+                {
                 return make_naive_tensor_descriptor(
                     make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-#endif
+                }
             }
             else
             {
@@ -290,7 +296,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
             if constexpr(BBlockLdsExtraN)
             {
-#if B_BLOCK_BANK_CONFLICT_FREE_WRW
+                if constexpr(BBlockLdsExtraN1Wrw)
+                {
                 constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
                     make_tuple(Number<1>{},
                                Number<K0PerBlock>{},
@@ -315,14 +322,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
                 return b_block_desc_b_k0_n_k1_tmp;
-#else
+                }
+                else
+                {
                 return make_naive_tensor_descriptor(
                     make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                                Number<NPerBlock + 1>{} * K1,
                                K1,
                                I1));
-#endif
+                }
             }
             else
             {
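
All four descriptor getters above share the same device: an immediately invoked constexpr lambda whose if constexpr branches return descriptors of different types, which works because the discarded branch does not take part in return-type deduction. A stripped-down sketch of that mechanism, using placeholder types (PaddedLayout and FlatLayout are not composable_kernel types):

#include <cstdio>
#include <type_traits>

// Minimal stand-ins for the two descriptor types a getter can return.
struct PaddedLayout { static constexpr int stride = 68; };
struct FlatLayout   { static constexpr int stride = 64; };

// Same shape as GetABlockDescriptor_K0PerBlock_MPerBlock_K1 after the change:
// an immediately invoked constexpr lambda; only the taken if constexpr branch
// contributes a return statement, so the branch types need not match.
template <bool ExtraPad>
constexpr auto make_block_layout()
{
    constexpr auto layout = [&]() {
        if constexpr(ExtraPad)
        {
            return PaddedLayout{};
        }
        else
        {
            return FlatLayout{};
        }
    }();
    return layout;
}

int main()
{
    auto padded = make_block_layout<true>();
    auto flat   = make_block_layout<false>();
    static_assert(!std::is_same_v<decltype(padded), decltype(flat)>);
    std::printf("padded stride = %d, flat stride = %d\n", padded.stride, flat.stride);
    return 0;
}
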