Commit 19d99b65 authored by Adam Osewski

Add Cshuffle and results write to GMEM.

parent c1f7d9f2
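In outline, the new write path combines the accumulated C tile with the auxiliary D tensors through the CDE elementwise operation and stores the result E to global memory. A minimal host-side sketch of that fusion, with hypothetical names and none of the tiling or CShuffle staging done by the kernel:

#include <cstddef>
#include <vector>

// Sketch only: element-wise fusion of the accumulated GEMM result with the D
// tensors into E. The kernel below does the same per output tile, staged through LDS.
template <typename CDEOp>
void fuse_cde(const std::vector<float>& c_acc,           // accumulated GEMM result
              const std::vector<std::vector<float>>& ds, // auxiliary D tensors
              std::vector<float>& e,                     // output tensor E
              const CDEOp& cde_op)                       // user-supplied elementwise op
{
    for(std::size_t i = 0; i < c_acc.size(); ++i)
    {
        float v = c_acc[i];
        for(const auto& d : ds)
            v = cde_op(v, d[i]); // folded pairwise here; CK passes C and all Ds at once
        e[i] = v;
    }
}
// e.g. fuse_cde(c, ds, e, [](float c, float d) { return c + d; });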
@@ -53,6 +53,7 @@ template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename DsDataType,
typename Block2ETileMapKSplit,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -124,12 +125,10 @@ __global__ void
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
// const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
// const auto StrideC = gemm_desc_ptr[group_id].StrideC;
auto gridwise_gemm = GridwiseGemm();
auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
@@ -159,7 +158,6 @@ __global__ void
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
// Store partial results to auxiliary workspace.
gridwise_gemm.StorePartials(p_workspace);
}
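For intuition, the split-K flow here parks each k-batch's partial tile in a GMEM workspace (StorePartials) and a designated block later sums those partials (AccumulatePartials) before the final write. A minimal host-side sketch of the same idea, with a single dot product standing in for an output tile (hypothetical helper, not CK code):

#include <algorithm>
#include <vector>

float splitk_dot(const std::vector<float>& a, const std::vector<float>& b, int k_batch)
{
    const int K     = static_cast<int>(a.size());
    const int chunk = (K + k_batch - 1) / k_batch;
    std::vector<float> workspace(k_batch, 0.f); // one partial per k-batch ("StorePartials")
    for(int kb = 0; kb < k_batch; ++kb)
        for(int k = kb * chunk; k < std::min(K, (kb + 1) * chunk); ++k)
            workspace[kb] += a[k] * b[k];
    float c = 0.f;
    for(float p : workspace)                    // "AccumulatePartials": reduce the workspace
        c += p;
    return c;
}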
@@ -182,27 +180,33 @@ __global__ void
gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
// TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
// TODO: do fusion, CShuffle and store results to GMEM
// gridwise_gemm.RunWrite(results_buffer,
// p_c_grid,
// M,
// N,
// K,
// StrideA,
// StrideB,
// StrideC,
// MPadded,
// NPadded,
// KPadded,
// K0,
// k_batch,
// static_cast<void*>(p_shared),
// b2c_tile_map);
const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
const auto stride_e = gemm_desc_ptr[group_id].StrideE;
const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
constexpr auto NumDTensor = DsDataType::Size();
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
gridwise_gemm.template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map);
}
else
{
@@ -303,6 +307,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
@@ -687,6 +692,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
@@ -819,6 +825,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
@@ -861,6 +868,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
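For reference, the typed Ds pointers in the kernel above are recovered by walking the DsDataType type list at compile time and casting each type-erased slot. A standard-C++ sketch of the same pattern, using std::tuple and std::index_sequence in place of CK's Tuple and static_for (hypothetical helper names):

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Sketch: build a tuple of typed const pointers from an array of void* slots,
// casting each slot to the D data type at the same position in the type list.
template <typename... DTypes, std::size_t... Is>
std::tuple<const DTypes*...>
make_ds_pointers_impl(const std::array<const void*, sizeof...(DTypes)>& raw,
                      std::index_sequence<Is...>)
{
    return std::tuple<const DTypes*...>{static_cast<const DTypes*>(raw[Is])...};
}

template <typename... DTypes>
std::tuple<const DTypes*...>
make_ds_pointers(const std::array<const void*, sizeof...(DTypes)>& raw)
{
    return make_ds_pointers_impl<DTypes...>(raw, std::index_sequence_for<DTypes...>{});
}

// Usage: auto ds = make_ds_pointers<float, float>(raw_ds_pointers);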
@@ -44,6 +44,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -696,6 +697,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
block_2_etile_map);
}
// TODO Need to do CShuffle already here:
__device__ void StorePartials(void* __restrict__ p_workspace)
{
// M0 = grid_size
@@ -849,15 +851,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
// if (threadIdx.x == 0)
// {
// printf("w_grid_desc_m0_n0_m1_n1: [%d, %d, %d, %d]\n",
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I0),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I1),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I2),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I3));
// }
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -981,300 +974,285 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
}
// template <typename CThreadBufer,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// index_t NumDTensor_,
// typename DsDataType_,
// typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename CDEElementwiseOperation_,
// typename Block2ETileMap>
// __device__ void RunWrite(CThreadBufer c_thread_buf,
// const EDataType* __restrict__ p_workspace,
// DsGridPointer p_ds_grid,
// EDataType* __restrict__ p_e_grid,
// void* __restrict__ p_shared,
// const index_t KBatch,
// const CDEElementwiseOperation_& cde_element_op,
// const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// ds_grid_desc_mblock_mperblock_nblock_nperblock,
// const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// e_grid_desc_mblock_mperblock_nblock_nperblock,
// const Block2ETileMap& block_2_etile_map)
// {
// using DsGridDesc_M_N =
// remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
// DsGridDesc_M_N ds_grid_desc_m_n;
// const auto ds_grid_buf = generate_tuple(
// [&](auto i) {
// return make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_ds_grid[i],
// ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
// },
// Number<NumDTensor_>{});
// auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// static_for<0, NumDTensor, 1>{}([&](auto j) {
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
// });
// const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
// // using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
// // remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
// // DsGridDesc_M_N{}))>;
// // DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
// ds_grid_desc_mblock_mperblock_nblock_nperblock;
// // static_for<0, NumDTensor, 1>{}([&](auto j) {
// // ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
// // });
// // const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
// // shuffle C and write out
// static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
// NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
// "wrong!");
// constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// // TODO: hacky, fix it!
// constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
// blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// // TODO: hacky, fix it!
// // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
// blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
// constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
// constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
// constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
// constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
// constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
// constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
// constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
// GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
// auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// static_cast<CShuffleDataType*>(p_shared),
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
// make_tuple(
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
// M1, // M1 = MWave
// M2, // M2 * M3 * M4 = MPerXdl
// M3,
// M4)),
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
// N1, // N1 = NWave
// N2))), // N2 = NPerXdl
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(
// Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// // calculate origin of thread output tensor on global memory
// // blockwise GEMM c matrix starting index
// const auto c_thread_mtx_on_block =
// blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
// const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
// const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
// const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
// make_tuple(Sequence<0, 1, 2, 3, 4>{}),
// make_tuple(Sequence<0>{}));
// const auto m_thread_data_on_block_idx =
// m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
// make_multi_index(m_thread_data_on_block));
// const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
// make_tuple(Sequence<0, 1, 2>{}),
// make_tuple(Sequence<0>{}));
// const auto n_thread_data_on_block_idx =
// n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
// make_multi_index(n_thread_data_on_block));
// // shuffle: threadwise copy C from VGPR to LDS
// auto c_thread_copy_vgpr_to_lds =
// ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
// CShuffleDataType,
// decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// ck::tensor_operation::element_wise::PassThrough,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// I1,
// I1,
// M2,
// I1,
// M4,
// I1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// 7,
// 1,
// InMemoryDataOperationEnum::Set,
// 1,
// true>{
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// make_multi_index(0,
// 0,
// m_thread_data_on_block_idx[I1],
// n_thread_data_on_block_idx[I1],
// m_thread_data_on_block_idx[I2],
// m_thread_data_on_block_idx[I3],
// m_thread_data_on_block_idx[I4],
// n_thread_data_on_block_idx[I2]),
// ck::tensor_operation::element_wise::PassThrough{}};
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_desc_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
// Number<NumDTensor_>{}));
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_buf_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_buf),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_buf[i]; },
// Number<NumDTensor_>{}));
// // tuple of starting index of C/Ds blockwise copy
// const auto idx_c_ds_block_begin = container_concat(
// make_tuple(make_multi_index(0, 0, 0, 0)),
// generate_tuple(
// [&](auto) {
// return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
// },
// Number<NumDTensor_>{}));
// // space filling curve for threadwise C in VGPR before shuffle
// constexpr auto sfc_c_vgpr =
// SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// 1,
// 1,
// M2,
// 1,
// M4,
// 1>>{};
// // space filling curve for shuffled blockwise C/D/E
// constexpr auto sfc_cde_block =
// SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
// Sequence<0, 2, 1, 3>,
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
// constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// // blockwise copy C/D/E between LDS and global
// auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
// ThisThreadBlock,
// decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
// Tuple<EDataType>,
// decltype(c_ds_desc_refs),
// decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
// CDEElementwiseOperation_,
// Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// // Sequence support
// // arbitray type
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
// CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
// Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
// Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
// 3, // index_t VectorDim,
// CDEShuffleBlockTransferScalarPerVector_NPerBlock,
// sequence_merge_t<
// Sequence<true>,
// uniform_sequence_gen_t<NumDTensor_,
// false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
// Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
// {c_ds_desc_refs,
// idx_c_ds_block_begin,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
// cde_element_op};
// static_for<0, num_access, 1>{}([&](auto access_id) {
// // make sure it's safe to write to LDS
// block_sync_lds();
// // each thread write its data from VGPR to LDS
// c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
// c_thread_buf,
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// c_shuffle_block_buf);
// // make sure it's safe to read from LDS
// block_sync_lds();
// // each block copy its data from LDS to global
// cde_block_copy_lds_and_global.Run(
// c_ds_desc_refs,
// c_ds_buf_refs,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// tie(e_grid_buf));
// if constexpr(access_id < num_access - 1)
// {
// constexpr auto cde_lds_and_global_step =
// sfc_cde_block.GetForwardStep(access_id);
// // move on Ds
// static_for<0, NumDTensor_, 1>{}([&](auto i) {
// cde_block_copy_lds_and_global.MoveSrcSliceWindow(
// c_ds_desc_refs, i + I1, cde_lds_and_global_step);
// });
// // move on E
// cde_block_copy_lds_and_global.MoveDstSliceWindow(
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// I0,
// cde_lds_and_global_step);
// }
// });
// }
template <typename Block2ETileMap>
__device__ void RunWrite(DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const index_t M,
const index_t N,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const CDEElementwiseOperation& cde_element_op,
const Block2ETileMap& block_2_etile_map)
{
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
DsGridDesc_M_N ds_grid_desc_m_n;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
});
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
// shuffle C and write out
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// divide block work by [M, N, K]
const auto block_work_idx = block_2_etile_map.GetBottomIndex();
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copies its data from LDS to global
cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
};
} // namespace ck
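As a closing note, the access loop at the end of RunWrite follows the usual two-phase shuffle pattern: write register results to LDS, synchronize the block, then stream from LDS to global memory with a store-friendly thread mapping. A heavily simplified HIP-style sketch of that pattern, with a hypothetical kernel, one value per thread, and a plain add standing in for cde_element_op:

// Sketch only: stage per-thread results through LDS before the global store, so the
// GMEM write can use a different (coalesced) mapping than the accumulator layout.
// Assumes a launch with 256 threads per block.
__global__ void epilogue_shuffle_sketch(const float* acc, const float* d, float* e, int n)
{
    __shared__ float tile[256]; // stand-in for the CShuffle LDS buffer (p_shared)

    const int gid = blockIdx.x * blockDim.x + threadIdx.x;

    // Phase 1: every thread stages its accumulator value in LDS.
    tile[threadIdx.x] = (gid < n) ? acc[gid] : 0.f;
    __syncthreads(); // the whole tile must be visible before any thread reads it

    // Phase 2: read back from LDS (same mapping here for brevity), fuse one D tensor
    // with an add as a stand-in for the CDE elementwise op, and write E to GMEM.
    if(gid < n)
        e[gid] = tile[threadIdx.x] + d[gid];
}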