Commit 35978330 authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent d3bd5922
......@@ -17,7 +17,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
......@@ -33,8 +33,8 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
......@@ -49,7 +49,7 @@ __global__ void
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
a_element_op,
b_element_op,
c_element_op,
......@@ -74,8 +74,8 @@ template <
index_t MPerXdl,
index_t NPerXdl,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -92,9 +92,9 @@ template <
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
......@@ -110,8 +110,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
// TODO: need to calculate LDS usage for C shuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
......@@ -130,6 +129,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
}();
return a_block_desc_k0_m_k1;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
......@@ -145,14 +151,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
}();
return b_block_desc_k0_n_k1;
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle>{},
Number<MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle>{},
Number<NWave * NPerXdl>{}));
return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
}
// TODO: need to calculate LDS usage for C shuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
// LDS allocation for C shuffle in LDS
constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();
constexpr auto c_block_size =
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatC));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -166,8 +213,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXdl * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXdl)) == 0,
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
......@@ -215,7 +262,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
......@@ -224,20 +271,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
const auto c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(MBlock, Number<MRepeat>{}, Number<MWave * MPerXdl>{})),
make_unmerge_transform(
make_tuple(NBlock, Number<NRepeat>{}, Number<NWave * NPerXdl>{}))),
make_tuple(make_unmerge_transform(make_tuple(
MBlock, Number<MXdlPerWave>{}, Number<MWave * MPerXdl>{})),
make_unmerge_transform(make_tuple(
NBlock, Number<NXdlPerWave>{}, Number<NWave * NPerXdl>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl;
return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
}
// return block_id to C matrix tile idx (m0, n0) mapping
......@@ -275,9 +322,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl =
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
CGridDesc_M_N{}))>;
using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
......@@ -290,8 +337,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl&
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
......@@ -303,7 +350,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid,
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
......@@ -323,34 +370,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
......@@ -428,21 +451,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
decltype(b_block_desc_k0_n_k1),
MPerXdl,
NPerXdl,
MRepeat,
NRepeat,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size,
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......@@ -496,12 +519,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// shuffle C and write out
{
static_assert(MRepeat % CShuffleMRepeatPerShuffle == 0 &&
NRepeat % CShuffleNRepeatPerShuffle == 0,
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
......@@ -521,29 +544,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<CShuffleMRepeatPerShuffle>{},
Number<MWave * MPerXdl>{},
I1,
Number<CShuffleNRepeatPerShuffle>{},
Number<NWave * NPerXdl>{}));
constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_tuple(make_freeze_transform(I0), // freeze mblock
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
make_tuple(
make_freeze_transform(I0), // freeze mblock
make_pass_through_transform(
Number<CShuffleMRepeatPerShuffle>{}), // M0 (MRepeat) per shuffle
Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
make_unmerge_transform(
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform(I0), // freeze nblock
make_pass_through_transform(
Number<CShuffleNRepeatPerShuffle>{}), // N0 (NRepeat) per shuffle
Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
make_unmerge_transform(
make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple(Sequence<0>{},
......@@ -596,8 +615,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
......@@ -626,48 +645,52 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle,
CShuffleMXdlPerWavePerShuffle,
MWave * MPerXdl,
1,
CShuffleNRepeatPerShuffle,
CShuffleNXdlPerWavePerShuffle,
NWave * NPerXdl>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
decltype(c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl),
decltype(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
{c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
make_multi_index(0, 0, 0, 0, 0, 0),
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
c_element_op};
constexpr auto mrepeat_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle, 0, 0, 0, 0);
constexpr auto nrepeat_forward_step =
make_multi_index(0, 0, 0, 0, CShuffleNRepeatPerShuffle, 0);
constexpr auto nrepeat_backward_step =
make_multi_index(0, 0, 0, 0, -CShuffleNRepeatPerShuffle, 0);
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
constexpr auto nxdlperwave_forward_step =
make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
constexpr auto nxdlperwave_backward_step =
make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mrepeat_iter) {
constexpr auto mrepeat = mrepeat_iter;
static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nrepeat_iter) {
constexpr bool nrepeat_forward_sweep =
(mrepeat % (2 * CShuffleMRepeatPerShuffle) == 0);
static_for<0,
NXdlPerWave,
CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
constexpr index_t nrepeat_value =
nrepeat_forward_sweep
? nrepeat_iter
: (NRepeat - nrepeat_iter - CShuffleNRepeatPerShuffle);
constexpr index_t nxdlperwave_value =
nxdlperwave_forward_sweep
? nxdlperwave_iter
: (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
constexpr auto nrepeat = Number<nrepeat_value>{};
constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
......@@ -675,7 +698,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(mrepeat, nrepeat, I0, I0, I0, I0, I0, I0),
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
......@@ -685,33 +708,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// LDS to global
c_block_copy_lds_to_global.Run(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
c_block_buf,
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
c_grid_buf);
// move on nrepeat dimension
if constexpr(nrepeat_forward_sweep &&
(nrepeat < NRepeat - CShuffleNRepeatPerShuffle))
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
(nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
nrepeat_forward_step);
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
nxdlperwave_forward_step);
}
else if constexpr((!nrepeat_forward_sweep) && (nrepeat > 0))
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
nrepeat_backward_step);
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
nxdlperwave_backward_step);
}
});
// move on mrepeat dimension
if constexpr(mrepeat < MRepeat - CShuffleMRepeatPerShuffle)
// move on mxdlperwave dimension
if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
mrepeat_forward_step);
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
mxdlperwave_forward_step);
}
});
}
......
......@@ -20,10 +20,10 @@ using PassThrough_v2 = ck::tensor_operation::element_wise::PassThrough;
using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances =
std::tuple<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeate| NRepeate| _MBlock_MRepeat_MWaveMPerXdl| ScalarPerVector|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
......
......@@ -48,9 +48,9 @@ template <
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
......@@ -249,9 +249,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
CBlockTransferScalarPerVector_NWaveNPerXdl>;
// Argument
......@@ -281,7 +281,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......@@ -308,9 +308,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_ =
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
......@@ -325,8 +325,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_;
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -355,23 +355,24 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout
<< "arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{ "
<< arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
<< "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_"
"nwavenperxdl_{ "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I0)
<< ", "
<< arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I1)
<< ", "
<< arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I2)
<< ", "
<< arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I3)
<< ", "
<< arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I4)
<< ", "
<< arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I5)
<< "}" << std::endl;
}
......@@ -404,7 +405,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -422,7 +423,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
......@@ -438,7 +439,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -456,7 +457,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
......
......@@ -34,8 +34,8 @@ using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeate| NRepeate| _MBlock_MRepeat_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat_NWaveNPerXdl| _NWaveNPerXdl|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
// clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment