Commit eaa68635 authored by Adam Osewski

Fix RunWrite.

parent 125a39d1
@@ -101,7 +101,7 @@ __global__ void
     index_t gemm_tile_id_end = grid_size_grp;
     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                               typename GridwiseGemm::AccType,
-                              GridwiseGemm::GetMPerXdl() * GridwiseGemm::GetNPerXdl(),
+                              GridwiseGemm::GetMXdlPerWave() * GridwiseGemm::GetNXdlPerWave(),
                               GridwiseGemm::GetCThreadBufferVectorSize(),
                               true>
         results_buffer;
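The per-thread accumulator buffer is now sized by the wave's xdlops repeats rather than by a single xdlops tile: it holds MXdlPerWave * NXdlPerWave vectors, each GetCThreadBufferVectorSize() (i.e. GetRegSizePerXdlops()) values long. A minimal host-side sketch of that arithmetic with illustrative sizes (MPerXdl = NPerXdl = 32, 64-lane wave, 2x2 repeats; none of these values are taken from this kernel's configuration):

```cpp
#include <cstdio>

int main()
{
    // Illustrative xdlops configuration, not this kernel's defaults.
    constexpr int MPerXdl     = 32;
    constexpr int NPerXdl     = 32;
    constexpr int WaveSize    = 64;
    constexpr int MXdlPerWave = 2; // xdlops tiles a wave repeats in M
    constexpr int NXdlPerWave = 2; // xdlops tiles a wave repeats in N

    // Accumulator registers one lane holds for a single xdlops tile.
    constexpr int RegsPerXdlops = (MPerXdl * NPerXdl) / WaveSize; // 16

    // results_buffer is sized per repeat, not per single xdlops tile:
    // MXdlPerWave * NXdlPerWave vectors of RegsPerXdlops values each.
    constexpr int NumVectors = MXdlPerWave * NXdlPerWave;  // 4
    constexpr int TotalRegs  = NumVectors * RegsPerXdlops; // 64

    std::printf("%d vectors x %d regs = %d accumulators per lane\n",
                NumVectors, RegsPerXdlops, TotalRegs);
    return 0;
}
```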
@@ -308,8 +308,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     // M0 - MBlock
     // M1 - MPerBlock
     // N0 - NBlock
-    // N1 - NVecPerThread
-    // N2 - NVecSize
+    // N1 - N repeats
+    // N2 - NVecSize * cluster length
     template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
     MakeEGridDescriptor_M0M1_N0N1N2(const EGridDesc_M_N& e_grid_desc_m_n)
@@ -330,33 +330,18 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
         constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();

-        // # of threads in NDim * vector load size * # repeats per thread
-        constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
-                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
-                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
-        constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
-
-        const auto e_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
-            e_grid_desc_mblock_mperblock_nblock_nperblock,
-            make_tuple(make_pass_through_transform(
-                           e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
-                       make_pass_through_transform(
-                           e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
-                       make_pass_through_transform(
-                           e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
-                       make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
-            e_grid_desc_m0m1_n0n1pad,
-            make_tuple(
-                make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I0)),
-                make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I1)),
-                make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I2)),
-                make_unmerge_transform(make_tuple(
-                    workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) * cluster_length_reduce.At(I2),
-                    workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)))),
+        const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(
+                make_pass_through_transform(
+                    e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
+                make_pass_through_transform(
+                    e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
+                make_pass_through_transform(
+                    e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
+                make_unmerge_transform(make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
+                                                  workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
+                                                      cluster_length_reduce.At(I2)))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
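With the padding transform gone, MakeEGridDescriptor_M0M1_N0N1N2 unmerges NPerBlock directly into N1 repeats and an inner dimension of N2 * cluster-length elements, so the split has to be exact. A rough host-side sketch of the index arithmetic that unmerge implies, using made-up sizes (NPerBlock = 128, vector size 4, 16 threads along N) rather than the library's descriptor machinery:

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    // Hypothetical tile parameters, chosen only for illustration.
    constexpr int NPerBlock = 128; // N elements handled by one block
    constexpr int N2        = 4;   // vector load/store size per thread
    constexpr int ClusterN  = 16;  // threads along N in the reduce cluster
    constexpr int InnerN    = N2 * ClusterN;      // length of the unmerged inner dim
    constexpr int N1        = NPerBlock / InnerN; // repeats per thread
    static_assert(NPerBlock % InnerN == 0, "no padding anymore: the split must be exact");

    // A flat n index inside the block maps to (n1, n2) the same way
    // make_unmerge_transform(make_tuple(N1, N2 * ClusterN)) splits it.
    for(int n = 0; n < NPerBlock; ++n)
    {
        const int n1 = n / InnerN; // slow, outer coordinate (repeat index)
        const int n2 = n % InnerN; // fast, inner coordinate (thread * vector lane)
        assert(n1 * InnerN + n2 == n);
    }
    std::printf("N1 = %d repeats of %d elements each\n", N1, InnerN);
    return 0;
}
```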
@@ -436,8 +421,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                      GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
         {
-            auto K_t = KBatch * KPerBlock;
-            if(!(K % K_t == 0))
+            if(!(K % KPerBlock == 0))
             {
                 if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                 {
@@ -540,6 +524,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             }
         }

+        const auto k_batch_size =
+            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KBatch;
+
+        if(k_batch_size < KPerBlock)
+        {
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The k-batch size (" << k_batch_size
+                          << ") value is less than KPerBlock!\n"
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
+            return false;
+        }
+
+        if(k_batch_size % KPerBlock != 0)
+        {
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The k-batch size (" << k_batch_size
+                          << ") value is not a multiple of KPerBlock!\n"
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
+            return false;
+        }
+
         // check gridwise gemm pipeline
         const auto num_k_loop =
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
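The new guards in CheckValidity reject launches where a single k-batch would receive less K than one block tile, or an amount that is not a whole number of KPerBlock tiles. A simplified host-side sketch of the same check, assuming the descriptor lengths multiply back to K and using made-up values (K = 1024, KPerBlock = 64):

```cpp
#include <cstdio>

// Stand-in for the descriptor-based check: k_batch_size is the amount of K
// each k-batch has to process.
bool check_k_batch(int K, int KBatch, int KPerBlock)
{
    const int k_batch_size = K / KBatch; // GetLength(I0) * GetLength(I2) assumed == K
    if(k_batch_size < KPerBlock)
    {
        std::printf("k-batch size %d is less than KPerBlock %d\n", k_batch_size, KPerBlock);
        return false;
    }
    if(k_batch_size % KPerBlock != 0)
    {
        std::printf("k-batch size %d is not a multiple of KPerBlock %d\n", k_batch_size, KPerBlock);
        return false;
    }
    return true;
}

int main()
{
    std::printf("%d\n", check_k_batch(1024, 4, 64));  // 256 per batch -> accepted
    std::printf("%d\n", check_k_batch(1024, 32, 64)); // 32 per batch  -> rejected
    return 0;
}
```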
@@ -624,8 +635,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
     __device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }

-    __device__ __host__ static constexpr auto GetMPerXdl() { return MPerXdl; }
-    __device__ __host__ static constexpr auto GetNPerXdl() { return NPerXdl; }
+    __device__ __host__ static constexpr auto GetMXdlPerWave() { return MXdlPerWave; }
+    __device__ __host__ static constexpr auto GetNXdlPerWave() { return NXdlPerWave; }

     __device__ static constexpr auto GetCThreadBufferVectorSize()
     {
@@ -646,7 +657,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         return BlockwiseGemmT::xdlops_gemm.GetRegSizePerXdlops();
     }

-    template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
+    template <typename Block2ETileMap, typename CThreadBuf>
     __device__ static void RunGEMM(const ADataType* __restrict__ p_a_grid,
                                    const BDataType* __restrict__ p_b_grid,
                                    void* __restrict__ p_shared,
@@ -656,7 +667,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                    const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                    const Block2ETileMap& block_2_etile_map,
                                    CThreadBuf& c_thread_buf,
-                                   const index_t k_tiles)
+                                   const index_t k_batch,
+                                   const index_t next_k_tiles)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -726,16 +738,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                   true,
                                                   NumGemmKPrefetchStage>;

-        const index_t ak0_start_idx = kbatch_id * AK0PerBlock;
-        const index_t bk0_start_idx = kbatch_id * BK0PerBlock;
-
-        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
-        {
-            printf("[RunGEMM] bid: %d, ak0_start_idx: %d, bk0_start_idx: %d\n",
-                   static_cast<index_t>(blockIdx.x),
-                   ak0_start_idx,
-                   bk0_start_idx);
-        }
+        const index_t num_k_tiles_per_batch =
+            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
+            (KPerBlock * k_batch);
+        const index_t ak0_start_idx = kbatch_id * num_k_tiles_per_batch * AK0PerBlock.value;
+        const index_t bk0_start_idx = kbatch_id * num_k_tiles_per_batch * BK0PerBlock.value;

         // A matrix blockwise copy
         auto a_blockwise_copy =
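With the whole K range covered by one descriptor, each k-batch now owns num_k_tiles_per_batch consecutive KPerBlock tiles and starts num_k_tiles_per_batch * AK0PerBlock rows further along AK0. A small sketch of that offset arithmetic with illustrative numbers (K = 1024, KPerBlock = 64, 4 k-batches; AK1 = 8 is an assumption, not a value taken from this kernel):

```cpp
#include <cstdio>

int main()
{
    // Hypothetical sizes, for illustration only.
    constexpr int K           = 1024;
    constexpr int KPerBlock   = 64;
    constexpr int k_batch     = 4;
    constexpr int AK1         = 8;               // assumed K1 vector length
    constexpr int AK0PerBlock = KPerBlock / AK1; // AK0 rows consumed per K tile

    // Same arithmetic as RunGEMM: total K tiles split evenly across k-batches.
    constexpr int num_k_tiles_per_batch = K / (KPerBlock * k_batch);

    for(int kbatch_id = 0; kbatch_id < k_batch; ++kbatch_id)
    {
        const int ak0_start_idx = kbatch_id * num_k_tiles_per_batch * AK0PerBlock;
        std::printf("kbatch %d starts at AK0 row %d and owns %d K tiles\n",
                    kbatch_id, ak0_start_idx, num_k_tiles_per_batch);
    }
    return 0;
}
```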
@@ -777,25 +784,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock.value, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock.value, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
-
-        // TODO: what if AK1 != BK1 ???
-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(k_tiles);
-        // __builtin_amdgcn_readfirstlane((a_grid_desc_ak0_m_ak1.GetLength(I1) *
-        //                                 a_grid_desc_ak0_m_ak1.GetLength(I3)) /
-        //                                 KPerBlock);
-
-        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
-        {
-            printf("[RunGEMM] bid: %d, num_k_block_main_loop %d\n",
-                   static_cast<index_t>(blockIdx.x),
-                   num_k_block_main_loop);
-        }
+        const auto gridwise_gemm_pipeline = GridwiseGemmPipe();
+        const index_t num_k_block_main_loop =
+            __builtin_amdgcn_readfirstlane(next_k_tiles * num_k_tiles_per_batch);
+        const bool has_k_block_main_loop =
+            gridwise_gemm_pipeline.CalculateHasMainLoop(num_k_block_main_loop);

         bool clear_c_thread_buf = true;
@@ -813,7 +810,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                KPack,
                                                LoopSched>();

-        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+        if(has_k_block_main_loop)
+        {
+            gridwise_gemm_pipeline.template Run<true>(a_grid_desc_ak0_m_ak1,
                                                       a_block_desc_ak0_m_ak1,
                                                       a_blockwise_copy,
                                                       a_grid_buf,
@@ -830,8 +829,28 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                       num_k_block_main_loop,
                                                       clear_c_thread_buf);
         }
+        else
+        {
+            gridwise_gemm_pipeline.template Run<false>(a_grid_desc_ak0_m_ak1,
+                                                       a_block_desc_ak0_m_ak1,
+                                                       a_blockwise_copy,
+                                                       a_grid_buf,
+                                                       a_block_buf,
+                                                       a_block_slice_copy_step,
+                                                       b_grid_desc_bk0_n_bk1,
+                                                       b_block_desc_bk0_n_bk1,
+                                                       b_blockwise_copy,
+                                                       b_grid_buf,
+                                                       b_block_buf,
+                                                       b_block_slice_copy_step,
+                                                       blockwise_gemm,
+                                                       c_thread_buf,
+                                                       num_k_block_main_loop,
+                                                       clear_c_thread_buf);
+        }
+    }

-    template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
+    template <typename Block2ETileMap, typename CThreadBuf>
     __device__ static void RunGEMM(const void* __restrict__ p_a_grid_,
                                    const void* __restrict__ p_b_grid_,
                                    void* __restrict__ p_shared,
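Dropping HasMainKBlockLoop from RunGEMM's template parameters means the main-loop predicate is now evaluated at run time and dispatched to the two compile-time instantiations of the pipeline's Run. A generic sketch of that runtime-bool-to-template-bool dispatch; the Pipeline type and its members here are illustrative stand-ins, not the library's classes:

```cpp
#include <cstdio>

// Toy "pipeline": the bool template parameter selects which body is compiled.
struct Pipeline
{
    template <bool HasMainLoop>
    void Run(int num_loops) const
    {
        if constexpr(HasMainLoop)
            std::printf("main-loop path, %d iterations\n", num_loops);
        else
            std::printf("tail-only path\n");
    }
    static bool CalculateHasMainLoop(int num_loops) { return num_loops > 1; }
};

// Runtime flag picks one of two fully specialized instantiations, mirroring
// if(has_k_block_main_loop) { Run<true>(...) } else { Run<false>(...) }.
void dispatch(const Pipeline& p, int num_loops)
{
    if(Pipeline::CalculateHasMainLoop(num_loops))
        p.Run<true>(num_loops);
    else
        p.Run<false>(num_loops);
}

int main()
{
    Pipeline p;
    dispatch(p, 8); // takes the main-loop instantiation
    dispatch(p, 1); // takes the tail-only instantiation
    return 0;
}
```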
@@ -840,21 +859,21 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                    const index_t M,
                                    const index_t N,
                                    const index_t K,
-                                   const index_t StrideA,
-                                   const index_t StrideB,
-                                   const index_t KBatch,
+                                   const index_t stride_a,
+                                   const index_t stride_b,
+                                   const index_t k_batch,
                                    const Block2ETileMap& block_2_etile_map,
                                    CThreadBuf& c_thread_buf,
-                                   const index_t k_tiles)
+                                   const index_t next_k_tiles)
     {
         const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
         const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);

         // tensor descriptors for block/thread-wise copy
-        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
-        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
+        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, stride_a, k_batch);
+        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, stride_b, k_batch);

-        RunGEMM<HasMainKBlockLoop>(p_a_grid,
+        RunGEMM(p_a_grid,
                 p_b_grid,
                 p_shared,
                 a_element_op,
@@ -863,7 +882,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 b_grid_desc_bk0_n_bk1,
                 block_2_etile_map,
                 c_thread_buf,
-                k_tiles);
+                k_batch,
+                next_k_tiles);
     }

     template <typename CThreadBuf>
@@ -1098,24 +1118,24 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                           CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>{};
     }

+    // M0 - 1
+    // M1 - M elements per thread
+    // N0 - 1
+    // N1 - N repeats per thread
+    // N2 - Vector load/store size
     __device__ static constexpr auto MakeReductionThreadDesc_M0M1_N0N1N2()
     {
         constexpr auto cluster_lengths = GetClusterLengthReduction_M0_N0N1();
-        constexpr auto N1_elems =
-            math::integer_divide_ceil(Number<NPerBlock>{}, cluster_lengths.At(I2));
-        static_assert(N1_elems % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0,
-                      "Invalid ReductionThreadDesc M0M1_N0N1N2! N1_elems have to be a multiple of "
-                      "CDEShuffleBlockTransferScalarPerVector_NPerBlock!");

         constexpr auto N2 = Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
-        constexpr auto N1 = math::integer_divide_ceil(N1_elems, N2);
+        constexpr auto N1 = Number<NPerBlock>{} / (Number<cluster_lengths.At(I2)>{} * N2);
         constexpr auto M1 = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_lengths.At(I0));

         static_assert(
-            Number<M1>{} * cluster_lengths.At(I0) >= Number<MPerBlock>{},
+            Number<M1>{} * cluster_lengths.At(I0) == Number<MPerBlock>{},
             "Invalid ReductionThreadDesc M0M1_N0N1N2! M1 * cluster_length[0] have to be grater "
             "or equal to MPerBlock.");
-        static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) >= Number<NPerBlock>{},
+        static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) == Number<NPerBlock>{},
                       "Invalid ReductionThreadDesc M0M1_N0N1N2! N1 * N2 * cluster_length[2] have "
                       "to be grater or equal to NPerBlock.");
@@ -1129,6 +1149,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     {
         constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
         constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);

+        static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
+                      "Error! ThisThreadBlock::GetNumOfThread() too small");
+
+        if(ThisThreadBlock::GetThreadId() >= reduce_cluster_desc.GetElementSize())
+        {
+            return;
+        }
+
         const auto reduce_thread_cluster_idx =
             reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
         const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
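Both reduction paths now assert that the thread block is at least as large as the reduce cluster and early-return the surplus threads before any cluster index is computed. A toy sketch of that guard, assuming a row-major thread-to-cluster mapping (the actual mapping is defined by make_cluster_descriptor):

```cpp
#include <cstdio>

// Illustrative numbers only: a 256-thread block reducing with a 16 x 8 cluster.
constexpr int BlockSize   = 256;
constexpr int ClusterM    = 16;
constexpr int ClusterN    = 8;
constexpr int ClusterSize = ClusterM * ClusterN; // only 128 threads participate

static_assert(BlockSize >= ClusterSize, "block must be able to host the reduce cluster");

void reduce_lane(int tid)
{
    // Mirrors the new early return: threads past the cluster do no reduction work.
    if(tid >= ClusterSize)
        return;
    const int m_cluster_id = tid / ClusterN; // assumed row-major decomposition
    const int n_cluster_id = tid % ClusterN;
    std::printf("tid %d -> (m = %d, n = %d)\n", tid, m_cluster_id, n_cluster_id);
}

int main()
{
    reduce_lane(5);   // participates
    reduce_lane(200); // returns early, would otherwise index out of the cluster
    return 0;
}
```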
@@ -1139,27 +1168,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         const auto workspace_grid_desc_m0m1_n0n1 =
             MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());

-        // # of threads in NDim * vector load size * # repeats per thread
-        constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
-                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
-                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
-        constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
-
-        const auto workspace_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
-            workspace_grid_desc_m0m1_n0n1,
-            make_tuple(make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
-                       make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
-                       make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
-                       make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
-            workspace_grid_desc_m0m1_n0n1pad,
-            make_tuple(
-                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I0)),
-                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I1)),
-                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I2)),
+        const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
+            workspace_grid_desc_m0m1_n0n1,
+            make_tuple(
+                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
+                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
+                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
                 make_unmerge_transform(make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
                                                   workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
                                                       cluster_length_reduce.At(I2)))),
@@ -1255,7 +1269,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     __device__ static void RunWrite(DsGridPointer p_ds_grid,
                                     EDataType* __restrict__ p_e_grid,
                                     /* void* __restrict__ p_shared, */
-                                    const AccumulationBuffer& acc_buff,
+                                    AccumulationBuffer& acc_buff,
                                     const index_t M,
                                     const index_t N,
                                     const std::array<index_t, NumDTensor> StrideDs,
@@ -1301,9 +1315,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             },
             Number<NumDTensor>{});

-        auto aux_vgpr_buf =
-            StaticBuffer<AddressSpaceEnum::Vgpr, EDataType, ScalarPerVector, true>{};
-
         constexpr auto d_vgpr_buf_desc = make_naive_tensor_descriptor_packed(
             make_tuple(I1, I1, I1, I1, Number<ScalarPerVector>{}));
@@ -1312,6 +1323,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
         constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);

+        static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
+                      "Error! ThisThreadBlock::GetNumOfThread() too small");
+
+        if(ThisThreadBlock::GetThreadId() >= reduce_cluster_desc.GetElementSize())
+        {
+            return;
+        }
+
         const auto reduce_thread_cluster_idx =
             reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
         const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
@@ -1344,6 +1363,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             },
             Number<NumDTensor>{});

+        // Each thread writes consecutive M rows and strided N columns
         auto e_grid_store =
             ThreadwiseTensorSliceTransfer_v1r3<EDataType,
                                                EDataType,
@@ -1368,7 +1388,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         constexpr auto MIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I1);
         constexpr auto NIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I3);

-        constexpr auto n1_step = cluster_length_reduce.At(I2);
+        constexpr auto n1_step = I1;
         constexpr auto d_grid_M1_fwd_step = make_multi_index(I0, I1, I0, I0, I0);
         constexpr auto d_grid_N1_fwd_step = make_multi_index(I0, I0, I0, n1_step, I0);
@@ -1410,7 +1430,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                     //         src_data_refs[I0],
                     //         src_data_refs[I1],
                     //         ...)
-                    unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
+                    unpack2(cde_element_op, tie(acc_buff(acc_buf_offset + I)), src_data_refs);
                 });

                 // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
@@ -1429,18 +1449,22 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 // }

                 e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
-                                 make_tuple(I0, I0, I0, I0, I0),
-                                 aux_vgpr_buf,
+                                 make_tuple(I0, m_idx, I0, n_idx, I0),
+                                 // aux_vgpr_buf,
+                                 acc_buff,
                                  e_grid_desc_m0m1_n0n1n2,
                                  e_grid_buf);

+                if constexpr(NIter != 1)
+                {
                     if constexpr(n_idx != (NIter - 1))
                     {
                         static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                             ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                                    d_grid_N1_fwd_step);
                         });
-                        e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_fwd_step);
+                        e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2,
+                                                        d_grid_N1_fwd_step);
                     }
                     else
                     {
@@ -1448,16 +1472,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                             ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                                    d_grid_N1_bwd_step);
                         });
-                        e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_bwd_step);
+                        e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2,
+                                                        d_grid_N1_bwd_step);
                     }
+                }
                 }
             }); // NIter

+            {
                 static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                     ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                            d_grid_M1_fwd_step);
                 });
                 e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_M1_fwd_step);
+            }
         }); // MIter
     }
 };
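Taken together, the RunWrite changes make the store window zig-zag over the per-thread sub-tiles: step forward through the N1 repeats, rewind N at the last repeat (skipping the N movement entirely when NIter == 1), then advance one M1 row. A host-side sketch of that traversal order, assuming the backward step rewinds the window by NIter - 1 repeats (its exact definition lies outside these hunks); the accumulator-side origin (m_idx, n_idx) passed to e_grid_store.Run is not modeled here:

```cpp
#include <cstdio>

// Illustrative walk of the store window over the per-thread (MIter x NIter)
// sub-tiles, mirroring the fwd/bwd slice-window moves in RunWrite.
int main()
{
    constexpr int MIter = 2; // M repeats per thread (illustrative)
    constexpr int NIter = 3; // N repeats per thread (illustrative)

    int m = 0, n = 0;
    for(int m_idx = 0; m_idx < MIter; ++m_idx)
    {
        for(int n_idx = 0; n_idx < NIter; ++n_idx)
        {
            std::printf("store window at (M1 = %d, N1 = %d)\n", m, n);
            if(NIter != 1)
            {
                if(n_idx != NIter - 1)
                    n += 1;           // d_grid_N1_fwd_step: one N1 repeat forward
                else
                    n -= (NIter - 1); // d_grid_N1_bwd_step: assumed to rewind to the first repeat
            }
        }
        m += 1;                       // d_grid_M1_fwd_step: next M1 row
    }
    return 0;
}
```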