Commit 1e5c712b authored by wangshaojie6's avatar wangshaojie6
Browse files

add template to distinguish the instance that need lds padding for wrw

parent 93871ca1
......@@ -244,7 +244,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
......@@ -285,7 +287,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
......
......@@ -11,9 +11,6 @@
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#define A_BLOCK_BANK_CONFLICT_FREE_WRW 1
#define B_BLOCK_BANK_CONFLICT_FREE_WRW 1
namespace ck {
template <typename GridwiseGemm,
......@@ -112,7 +109,9 @@ template <index_t BlockSize,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static constexpr auto I0 = Number<0>{};
......@@ -126,21 +125,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = Number<128>{};
static constexpr auto ElePerBank = Number<BankLength / sizeof(FloatAB)>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
// M1 & N1
static constexpr auto ElePerBank = Number<BankLength / sizeof(FloatAB)>{};
// M1 & M0
static constexpr auto M1PerBlock = Number<ElePerBank / K1Value>{};
static constexpr auto N1PerBlock = Number<ElePerBank / K1Value>{};
// M0 & N0
static constexpr auto M0PerBlock = Number<MPerBlock / M1PerBlock>{};
static constexpr auto N0PerBlock = Number<NPerBlock / M1PerBlock>{};
static constexpr auto M1Padding = I4;
// M1 padding num
static constexpr auto M1Padding = Number<4>{};
static constexpr auto N1Padding = M1Padding;
// N1 & N0
static constexpr auto N1PerBlock = Number<ElePerBank / K1Value>{};
static constexpr auto N0PerBlock = Number<NPerBlock / M1PerBlock>{};
static constexpr auto N1Padding = I4;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -150,30 +148,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
#if A_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(
Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_block_desc_k0_m_k1_tmp;
#else
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
#endif
if constexpr(ABlockLdsExtraM1Wrw)
{
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(
Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_block_desc_k0_m_k1_tmp;
}
else
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
}
else
{
......@@ -193,39 +194,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_desc_b_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
#if A_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<1>{},
Number<K0PerBlock>{},
Number<M0PerBlock>{},
Number<M1PerBlock>{},
K1),
make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} *
(Number<M1PerBlock>{} * K1 + M1Padding),
Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_b_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<1>{}),
make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod_for_wrw(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return a_block_desc_b_k0_m_k1_tmp;
#else
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
#endif
if constexpr(ABlockLdsExtraM1Wrw)
{
constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<1>{},
Number<K0PerBlock>{},
Number<M0PerBlock>{},
Number<M1PerBlock>{},
K1),
make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} *
(Number<M1PerBlock>{} * K1 + M1Padding),
Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_b_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<1>{}),
make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod_for_wrw(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return a_block_desc_b_k0_m_k1_tmp;
}
else
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
}
}
else
{
......@@ -246,31 +250,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
#if B_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(
Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_block_desc_k0_n_k1_tmp;
#else
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
#endif
if constexpr(BBlockLdsExtraN1Wrw)
{
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(
Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_block_desc_k0_n_k1_tmp;
}
else
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
}
else
{
......@@ -290,39 +296,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
#if B_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<1>{},
Number<K0PerBlock>{},
Number<N0PerBlock>{},
Number<N1PerBlock>{},
K1),
make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} *
(Number<N1PerBlock>{} * K1 + N1Padding),
Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_b_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<1>{}),
make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod_for_wrw(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return b_block_desc_b_k0_n_k1_tmp;
#else
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
#endif
if constexpr(BBlockLdsExtraN1Wrw)
{
constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<1>{},
Number<K0PerBlock>{},
Number<N0PerBlock>{},
Number<N1PerBlock>{},
K1),
make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} *
(Number<N1PerBlock>{} * K1 + N1Padding),
Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_b_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<1>{}),
make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod_for_wrw(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return b_block_desc_b_k0_n_k1_tmp;
}
else
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
}
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment