Commit 98def248 authored by Adam Osewski

Rework RunWrite.

parent bbd26e10
......@@ -157,10 +157,12 @@ __global__ void
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
GridwiseGemm::StorePartials(p_workspace, results_buffer);
}
// With CShuffle at StorePartials, all workgroups have to store
// their partials to workspace GMEM.
// TODO: The reduction workgroup doesn't have to store its own results to GMEM!
// It would be enough to keep them in registers and, during AccumulatePartials,
// do CShuffle in flight while loading the partial products of other peer workgroups.
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
work_scheduler.FlagFinished();
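// Split-K flow: every workgroup CShuffles its partials while storing them to the GLC
// workspace (StorePartials) and flags completion; the reduction workgroup accumulates
// the peer slabs (AccumulatePartials), and RunWrite applies the CDE elementwise op,
// reading Ds and storing E directly from registers.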
......@@ -171,10 +173,20 @@ __global__ void
index_t neighbour_count =
work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
constexpr auto workspace_thread_desc_m0m1_n0n1n2 =
GridwiseGemm::MakeReductionThreadDesc_M0M1_N0N1N2();
StaticBuffer<AddressSpaceEnum::Vgpr,
typename GridwiseGemm::CShuffleDataT,
workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(),
true>
acc_buff{};
acc_buff.Clear();
// Accumulate only when there are at least two workgroups processing split-K data tiles
// across the same MN output tile.
if(neighbour_count > 0)
GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1);
GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(neighbour_count);
......@@ -195,17 +207,17 @@ __global__ void
GridwiseGemm::template RunWrite(p_ds_grid,
p_e_grid,
acc_buff,
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map);
}
else if(work_scheduler.HasTile())
{
// TODO Move this just before StorePartials!
work_scheduler.WaitForReduction();
}
} while(work_scheduler.HasTile());
......@@ -757,7 +769,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
<< ", grid_size: " << grid_size << ", flag_count: " << flag_count
<< ", p_flags: " << p_flags << ", workspace_ptr: " << dev_gemm_workspace
<< ", acc_workspace_size_bytes: " << acc_workspace_size_bytes
<< std::endl;
<< ", kbatch: " << arg.K_BATCH << std::endl;
}
auto preprocess = [&]() {
......@@ -995,7 +1007,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// the amount of workspace bytes needed; it may be less due to the number of available CUs in the
// stream used to launch the kernel.
size_t size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(AccDataType), grid_size) +
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
flag_count * sizeof(uint32_t);
return size_bytes;
}
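// Worked example (hypothetical sizes, assuming one MPerBlock x NPerBlock slab per workgroup):
// grid_size = 240, 128x128 tiles of a 2-byte CShuffleDataType, flag_count = 240 gives
// 240 * 128 * 128 * 2 B + 240 * 4 B = 7,864,320 B + 960 B of workspace.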
......
......@@ -10,12 +10,13 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......@@ -80,6 +81,13 @@ template <typename ADataType,
PipelineVersion PipelineVer>
class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
{
template <index_t... Ids>
__device__ static bool is_thread_local_1d_id_idx()
{
const auto tid = get_thread_local_1d_id();
return ((tid == Ids) || ...);
}
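// Usage sketch (illustrative): the fold expression ORs one equality test per index, e.g.
//   if(is_thread_local_1d_id_idx<0>())          { /* only thread 0 of the block */ }
//   if(is_thread_local_1d_id_idx<0, 64, 128>()) { /* threads 0, 64 and 128 */ }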
static constexpr index_t NumDTensor = DsDataType::Size();
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
......@@ -106,7 +114,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
public:
using AccType       = AccDataType;
using CShuffleDataT = CShuffleDataType;
__host__ __device__ static auto CalculateMPadded(index_t M)
{
......@@ -327,10 +336,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
c_block_size * sizeof(CShuffleDataType));
}
// E desc for destination in blockwise copy
// M0 - MBlock
// M1 - MPerBlock
// N0 - NBlock
// N1 - NVecPerThread
// N2 - NVecSize
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
MakeEGridDescriptor_M0M1_N0N1N2(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
......@@ -345,18 +358,49 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();
// # of threads in NDim * vector load size * # repeats per thread
constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
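// e.g. (hypothetical) NPerBlock = 128 with 32 threads along N, 4-wide vectors and
// 1 repeat per thread: NPerBlockPadded = 32 * 4 * 1 = 128, so NPerBlockPad = 0;
// the right-pad below only kicks in when the cluster/vector tiling overshoots NPerBlock.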
const auto e_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_pass_through_transform(
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
e_grid_desc_m0m1_n0n1pad,
make_tuple(
make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I0)),
make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I1)),
make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I2)),
make_unmerge_transform(make_tuple(
workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) * cluster_length_reduce.At(I2),
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
return e_grid_desc_m0m1_n0n1n2;
}
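// Shape walk-through (hypothetical values): M = N = 512, MPerBlock = NPerBlock = 128,
// 32 threads along N with 4-wide vectors and 1 repeat per thread gives
// (M0, M1, N0, N1, N2) = (4, 128, 4, 32, 4): 4x4 MN blocks, each row of a block split
// into 32 vectors of 4 contiguous elements.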
// Ds desc for source in blockwise copy
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
MakeDsGridDescriptor_M0M1_N0N1N2(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) { return MakeEGridDescriptor_M0M1_N0N1N2(ds_grid_desc_m_n[i]); },
Number<NumDTensor>{});
}
......@@ -600,20 +644,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
__host__ __device__ static auto
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(index_t grid_size)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, NPerBlock, I1.value));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
}
return make_naive_tensor_descriptor(
make_tuple(grid_size, MPerBlock, I1.value, NPerBlock),
make_tuple(MPerBlock * NPerBlock, NPerBlock, NPerBlock, I1.value));
}
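// Layout sketch: each workgroup owns one MPerBlock x NPerBlock row-major slab selected by
// dim 0, i.e. element (b, m, 0, n) sits at offset b * MPerBlock * NPerBlock + m * NPerBlock + n.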
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
......@@ -850,21 +885,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
c_thread_buf);
}
// TODO: Need to do CShuffle already here:
template <typename CThreadBuf>
__device__ static void StorePartials(void* __restrict__ p_workspace,
void* __restrict__ p_shared,
const CThreadBuf& c_thread_buf)
{
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const auto workspace_grid_desc_m0_n0_m1_n1 =
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
using BlockwiseGemmT =
remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......@@ -880,161 +905,10 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
KPack,
LoopSched>())>;
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// M0 = grid_size -> MRepeats (MXdlPerWave)
// N0 = 1 -> NRepeats (NXdlPerWave)
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
make_pass_through_transform(w_grid_n0),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
make_pass_through_transform(M1), // MWave
make_pass_through_transform(N1), // NWave
make_pass_through_transform(M2), // mfma_instr.num_groups_per_blk
make_pass_through_transform(M3), // mfma_instr.num_input_blks
make_pass_through_transform(M4), // mfma_instr.group_size
make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
make_tuple(Sequence<0, 2>{},
Sequence<1, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{},
Sequence<8>{},
Sequence<9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_workspace_grid,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
const auto c_thread_mtx_on_block =
BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
auto c_thread_copy_vgpr_to_gmem = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
AccDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths
// N -> then M dims
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // DimAccessOrder
7, // DstVectorDim,
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{// DstResetCoordinateAfterRun
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index((static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
w_grid_buf);
}
template <typename CThreadBuf>
__device__ static void AccumulatePartials(void* __restrict__ p_workspace,
CThreadBuf& c_thread_buf,
uint32_t reduce_count)
{
using BlockwiseGemmT =
remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>())>;
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
true>
acc_buf{};
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const auto workspace_grid_desc_m0_n0_m1_n1 =
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"MXdlPerWave % CShuffleMXdlPerWavePerShuffle != 0 or "
"NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0,");
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
......@@ -1048,213 +922,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// M0 = grid_size -> MRepeats (MXdlPerWave)
// N0 = 1 -> NRepeats (NXdlPerWave)
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
make_pass_through_transform(w_grid_n0),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
make_pass_through_transform(M1), // MWave
make_pass_through_transform(N1), // NWave
make_pass_through_transform(M2), // mfma_instr.num_groups_per_blk
make_pass_through_transform(M3), // mfma_instr.num_input_blks
make_pass_through_transform(M4), // mfma_instr.group_size
make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
make_tuple(Sequence<0, 2>{},
Sequence<1, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{},
Sequence<8>{},
Sequence<9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
const auto c_thread_mtx_on_block =
BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_workspace_grid,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2), // SrcDesc,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), // DstDesc,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // DimAccessOrder,
7, // SrcVectorDim,
1, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// We do not need to read this workgroup's partial results since they're
// already in c_thread_buf
make_multi_index((static_cast<index_t>(blockIdx.x) + 1) * MXdlPerWave,
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2])};
using Accumulation =
ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
constexpr auto partial_acc_load_step =
make_multi_index(MXdlPerWave, I0, I0, I0, I0, I0, I0, I0);
// We do not need to read this workgroup's partial results since they're
// already in c_thread_buf
for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
{
acc_buf.Clear();
acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
w_grid_buf,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
acc_buf);
static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
[&](auto i_vec) { Accumulation::Calculate(c_thread_buf(i_vec), acc_buf[i_vec]); });
acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
partial_acc_load_step);
}
}
template <typename Block2ETileMap, typename CThreadBuf>
__device__ static void RunWrite(DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const index_t M,
const index_t N,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const CDEElementwiseOperation& cde_element_op,
const Block2ETileMap& block_2_etile_map,
const CThreadBuf& c_thread_buf)
{
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
DsGridDesc_M_N ds_grid_desc_m_n;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
});
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// shuffle C and write out
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// divide block work by [M, N, K]
const auto block_work_idx = block_2_etile_map.GetBottomIndex();
using BlockwiseGemmT =
remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>())>;
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
......@@ -1281,6 +951,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
const auto c_thread_mtx_on_block =
BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
......@@ -1338,32 +1011,44 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// M0 = grid_size
// M1 = MPerBlock
// N0 = 1
// N1 = NPerBlock
const auto workspace_grid_desc_m0_m1_n0_n1 =
MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_workspace_grid, workspace_grid_desc_m0_m1_n0_n1.GetElementSpaceSize());
// shuffle: blockwise copy C from LDS to workspace
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation,
InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(workspace_grid_desc_m0_m1_n0_n1),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
workspace_grid_desc_m0_m1_n0_n1,
make_multi_index(static_cast<index_t>(blockIdx.x), 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
......@@ -1376,8 +1061,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
M4,
1>>{};
// space filling curve for shuffled blockwise W in global mem
constexpr auto sfc_w_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
......@@ -1386,39 +1071,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
static_assert(num_access == sfc_w_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
......@@ -1435,28 +1088,348 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
workspace_grid_desc_m0_m1_n0_n1,
w_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto w_global_step = sfc_w_global.GetForwardStep(access_id);
// move the C-shuffle destination window on the workspace
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
workspace_grid_desc_m0_m1_n0_n1, w_global_step);
}
});
}
__device__ static constexpr auto GetClusterLengthReduction_M0_N0N1()
{
return Sequence<CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
I1.value,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>{};
}
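// e.g. (hypothetical) with CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// = Sequence<1, 32, 1, 8>, the reduction cluster is (M0, N0, N1) = (32, 1, 8):
// 256 threads arranged as 32 rows x 8 columns over the MN tile.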
__device__ static constexpr auto MakeReductionThreadDesc_M0M1_N0N1N2()
{
constexpr auto cluster_lengths = GetClusterLengthReduction_M0_N0N1();
constexpr auto N1_elems =
math::integer_divide_ceil(Number<NPerBlock>{}, cluster_lengths.At(I2));
static_assert(N1_elems % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0,
"Invalid ReductionThreadDesc M0M1_N0N1N2! N1_elems have to be a multiple of "
"CDEShuffleBlockTransferScalarPerVector_NPerBlock!");
constexpr auto N2 = Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
constexpr auto N1 = math::integer_divide_ceil(N1_elems, N2);
constexpr auto M1 = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_lengths.At(I0));
static_assert(
Number<M1>{} * cluster_lengths.At(I0) >= Number<MPerBlock>{},
"Invalid ReductionThreadDesc M0M1_N0N1N2! M1 * cluster_length[0] have to be grater "
"or equal to MPerBlock.");
static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) >= Number<NPerBlock>{},
"Invalid ReductionThreadDesc M0M1_N0N1N2! N1 * N2 * cluster_length[2] have "
"to be grater or equal to NPerBlock.");
return make_naive_tensor_descriptor_packed(make_tuple(I1, Number<M1>{}, I1, N1, N2));
}
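// Worked example (hypothetical config): MPerBlock = NPerBlock = 128, reduction cluster
// (32, 1, 8), CDEShuffleBlockTransferScalarPerVector_NPerBlock = 4:
//   N1_elems = 128 / 8 = 16, N2 = 4, N1 = 16 / 4 = 4, M1 = 128 / 32 = 4
// -> descriptor (1, 4, 1, 4, 4): each thread reduces 4 rows x 4 vectors of 4 elements.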
template <typename AccumulationBuffer>
__device__ static void AccumulatePartials(void* __restrict__ p_workspace,
AccumulationBuffer& acc_buff,
uint32_t reduce_count)
{
constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n0_cluster_id = reduce_thread_cluster_idx[I1]; // Always 0: N0 cluster length is 1
const auto thread_n1_cluster_id = reduce_thread_cluster_idx[I2];
constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();
const auto workspace_grid_desc_m0m1_n0n1 =
MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());
// # of threads in NDim * vector load size * # repeats per thread
constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
const auto workspace_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
workspace_grid_desc_m0m1_n0n1,
make_tuple(make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
workspace_grid_desc_m0m1_n0n1pad,
make_tuple(
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I0)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I1)),
make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I2)),
make_unmerge_transform(make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
cluster_length_reduce.At(I2)))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
CShuffleDataType,
workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(),
true>
partial_acc_buf{};
auto p_workspace_grid = reinterpret_cast<CShuffleDataType*>(p_workspace);
auto w_grid_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
p_workspace_grid, workspace_grid_desc_m0m1_n0n1n2.GetElementSpaceSize());
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
CShuffleDataType, // SrcData,
CShuffleDataType, // DstData,
decltype(workspace_grid_desc_m0m1_n0n1n2), // SrcDesc,
decltype(workspace_thread_desc_m0m1_n0n1n2), // DstDesc,
decltype(workspace_thread_desc_m0m1_n0n1n2.GetLengths()), // SliceLengths,
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder,
4, // SrcVectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{workspace_grid_desc_m0m1_n0n1n2,
// We do not need to read this workgroup's partial results since they're
// already in c_thread_buf
// make_multi_index((static_cast<index_t>(blockIdx.x) + 1),
// We want to have a thread raked access pattern
make_multi_index(
static_cast<index_t>(blockIdx.x),
thread_m_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
I0,
thread_n0_cluster_id,
thread_n1_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I4))};
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, CShuffleDataType>;
constexpr auto partial_acc_load_step = make_multi_index(I1, I0, I0, I0, I0);
// TODO: We do not need to read this workgroup's partial results since they're
// already in c_thread_buf
for(uint32_t i_t = 0; i_t < reduce_count; ++i_t)
{
partial_acc_buf.Clear();
acc_load.Run(workspace_grid_desc_m0m1_n0n1n2,
w_grid_buf,
workspace_thread_desc_m0m1_n0n1n2,
make_tuple(I0, I0, I0, I0, I0),
partial_acc_buf);
static_for<0, workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(), 1>{}(
[&](auto i_vec) {
Accumulation::Calculate(acc_buff(i_vec), partial_acc_buf[i_vec]);
});
acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0m1_n0n1n2, partial_acc_load_step);
}
}
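// Scalar sketch of the loop above (illustrative, not the actual indexing math):
//   for(uint32_t peer = 0; peer < reduce_count; ++peer)      // own slab included
//       for(index_t i = 0; i < elems_per_thread; ++i)
//           acc_buff(i) += workspace[(blockIdx.x + peer) * slab_size + thread_offset + i];
// GLC-coherent buffer loads make each peer's flushed partials visible to the reducer.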
template <typename Block2ETileMap, typename AccumulationBuffer>
__device__ static void RunWrite(DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
/* void* __restrict__ p_shared, */
const AccumulationBuffer& acc_buff,
const index_t M,
const index_t N,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const CDEElementwiseOperation& cde_element_op,
const Block2ETileMap& block_2_etile_map)
{
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_M0M1_N0N1N2(DsGridDesc_M_N{}))>;
constexpr index_t ScalarPerVector = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
DsGridDesc_M_N ds_grid_desc_m_n;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_m0m1_n0n1n2;
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
const auto e_grid_desc_m0m1_n0n1n2 = MakeEGridDescriptor_M0M1_N0N1N2(e_grid_desc_m_n);
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_m0m1_n0n1n2.GetElementSpaceSize());
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
});
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_m0m1_n0n1n2(j) = MakeEGridDescriptor_M0M1_N0N1N2(ds_grid_desc_m_n[j]);
});
// TODO: on MI300 we could use NonTemporal load, MI200 streaming mode?
auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m0m1_n0n1n2[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto ds_thread_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DDataType, ScalarPerVector, true>{};
},
Number<NumDTensor>{});
auto aux_vgpr_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, EDataType, ScalarPerVector, true>{};
constexpr auto d_vgpr_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, Number<ScalarPerVector>{}));
// divide block work by [M, N, K]
const auto block_work_idx = block_2_etile_map.GetBottomIndex();
constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n0_cluster_id = reduce_thread_cluster_idx[I1]; // Always 0: N0 cluster length is 1
const auto thread_n1_cluster_id = reduce_thread_cluster_idx[I2];
constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();
auto ds_grid_load = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
using SliceLengths = Sequence<I1, I1, I1, I1, ScalarPerVector>;
return ThreadwiseTensorSliceTransfer_v2<DDataType,
DDataType,
decltype(ds_grid_desc_m0m1_n0n1n2(i)),
decltype(d_vgpr_buf_desc),
SliceLengths,
Sequence<0, 1, 2, 3, 4>,
4,
ScalarPerVector,
1,
false>{
ds_grid_desc_m0m1_n0n1n2(i),
make_multi_index(
block_work_idx[I0],
thread_m_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
block_work_idx[I1],
thread_n0_cluster_id,
thread_n1_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I4))};
},
Number<NumDTensor>{});
auto e_grid_store =
ThreadwiseTensorSliceTransfer_v1r3<EDataType,
EDataType,
decltype(workspace_thread_desc_m0m1_n0n1n2),
decltype(e_grid_desc_m0m1_n0n1n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<I1, I1, I1, I1, ScalarPerVector>,
Sequence<0, 1, 2, 3, 4>,
4,
ScalarPerVector,
EGlobalMemoryDataOperation,
1,
false>{
e_grid_desc_m0m1_n0n1n2,
make_multi_index(
block_work_idx[I0],
thread_m_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
block_work_idx[I1],
thread_n0_cluster_id,
thread_n1_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)),
ck::tensor_operation::element_wise::PassThrough{}};
constexpr auto MIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I1);
constexpr auto NIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I3);
constexpr auto n1_step = cluster_length_reduce.At(I2);
constexpr auto d_grid_M1_fwd_step = make_multi_index(I0, I1, I0, I0, I0);
constexpr auto d_grid_N1_fwd_step = make_multi_index(I0, I0, I0, n1_step, I0);
constexpr auto d_grid_N1_bwd_step =
make_multi_index(I0, I0, I0, -1 * n1_step * (NIter - 1), I0);
constexpr auto thr_buf_N1_offset = Number<ScalarPerVector>{};
constexpr auto thr_buf_M1_offset = NIter * thr_buf_N1_offset;
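// Traversal sketch: the inner NIter loop sweeps forward through the N1 tiles, then the
// final backward step (-n1_step * (NIter - 1)) rewinds N so the next MIter row restarts
// at the same N origin, i.e. a plain raster over the (MIter x NIter) per-thread tile.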
static_for<0, MIter, 1>{}([&](auto m_idx) {
static_for<0, NIter, 1>{}([&](auto n_idx) {
// load multiple Ds:
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).Run(ds_grid_desc_m0m1_n0n1n2(d_idx),
ds_grid_buf(d_idx),
d_vgpr_buf_desc,
make_tuple(I0, I0, I0, I0, I0),
ds_thread_buf(d_idx));
});
constexpr auto acc_buf_offset =
m_idx * thr_buf_M1_offset + n_idx * thr_buf_N1_offset;
// apply pointwise function
static_for<0, ScalarPerVector, 1>{}([&](auto I) {
// get reference to src data
const auto src_data_ds_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& { return ds_thread_buf[iSrc][I]; },
Number<NumDTensor>{});
const auto src_data_refs = concat_tuple_of_reference(
tie(acc_buff[acc_buf_offset + I]), src_data_ds_refs);
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
});
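// Illustrative CDE op (assuming a Bilinear-style functor): cde_element_op(e, acc, d0)
// would compute e = alpha * acc + beta * d0, with acc coming from the split-K reduction
// and d0 streamed from the Ds tensors.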
e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
make_tuple(I0, I0, I0, I0, I0),
aux_vgpr_buf,
e_grid_desc_m0m1_n0n1n2,
e_grid_buf);
if constexpr(n_idx != (NIter - 1))
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_N1_fwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_fwd_step);
}
else
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_N1_bwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_bwd_step);
}
}); // NIter
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
d_grid_M1_fwd_step);
});
e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_M1_fwd_step);
}); // MIter
}
};
......