Commit 19d99b65 authored by Adam Osewski

Add Cshuffle and results write to GMEM.

parent c1f7d9f2
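
For orientation before the diff: this commit finishes the split-K epilogue path. Every work-group that computes a partial K-split result for an [M,N] tile stores it to a global-memory workspace; the tile's first block waits, accumulates the partials, and (new here) runs the CShuffle epilogue, fusing the D tensors and writing E to GMEM. A minimal sketch of that control flow in plain C++ (hypothetical names, not the CK API):

    // Hypothetical sketch of the tile lifecycle (illustrative only, not CK code).
    #include <cstdio>

    struct TileBlock { bool is_first_ksplit_block; };

    void process(const TileBlock& blk)
    {
        if(!blk.is_first_ksplit_block)
        {
            std::puts("StorePartials: write partial C tile to GMEM workspace");
        }
        else
        {
            std::puts("AccumulatePartials: reduce all partials from workspace");
            std::puts("work_scheduler.Reset: release workspace to waiting blocks");
            std::puts("RunWrite (new): CShuffle via LDS, fuse Ds, store E to GMEM");
        }
    }

    int main()
    {
        process(TileBlock{false}); // a non-first k-split block
        process(TileBlock{true});  // the block that owns the final write
    }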
@@ -53,6 +53,7 @@ template <typename GridwiseGemm,
           typename FloatA,
           typename FloatB,
           typename FloatC,
+          typename DsDataType,
           typename Block2ETileMapKSplit,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -124,12 +125,10 @@ __global__ void
     const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
     const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
-    // const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
     const auto K       = gemm_desc_ptr[group_id].K;
     const auto StrideA = gemm_desc_ptr[group_id].StrideA;
     const auto StrideB = gemm_desc_ptr[group_id].StrideB;
-    // const auto StrideC = gemm_desc_ptr[group_id].StrideC;

     auto gridwise_gemm   = GridwiseGemm();
     auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
@@ -159,7 +158,6 @@ __global__ void
     // if (changed group_id || next [M,N] tile)
     if(!b2c_tile_map.IsFirstKSplitBlock())
     {
-        // Store partial results to auxilliary workspace.
         gridwise_gemm.StorePartials(p_workspace);
     }
@@ -182,27 +180,33 @@ __global__ void
         gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
-        // TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)

         // Signal waiting blocks that they can start using their workspace.
         work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);

-        // TODO do fusion, cshuffle and store results to GMEM
-        // gridwise_gemm.RunWrite(results_buffer,
-        //                        p_c_grid,
-        //                        M,
-        //                        N,
-        //                        K,
-        //                        StrideA,
-        //                        StrideB,
-        //                        StrideC,
-        //                        MPadded,
-        //                        NPadded,
-        //                        KPadded,
-        //                        K0,
-        //                        k_batch,
-        //                        static_cast<void*>(p_shared),
-        //                        b2c_tile_map);
+        const auto p_e_grid  = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
+        const auto stride_e  = gemm_desc_ptr[group_id].StrideE;
+        const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
+
+        constexpr auto NumDTensor = DsDataType::Size();
+        using DsGridPointer        = decltype(GridwiseGemm::MakeDsGridPointer());
+
+        DsGridPointer p_ds_grid;
+
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+            // D pointer
+            p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
+        });
+
+        gridwise_gemm.template RunWrite(p_ds_grid,
+                                        p_e_grid,
+                                        static_cast<void*>(p_shared),
+                                        M,
+                                        N,
+                                        stride_ds,
+                                        stride_e,
+                                        cde_element_op,
+                                        b2c_tile_map);
     }
     else
     {
@@ -303,6 +307,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         AElementwiseOperation,
         BElementwiseOperation,
         CDEElementwiseOperation,
+        InMemoryDataOperationEnum::Set,
         GemmSpec,
         NumGemmKPrefetchStage,
         BlockSize,
@@ -687,6 +692,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         ADataType,
         BDataType,
         EDataType,
+        DsDataType,
         Block2ETileMapKSplit,
         AElementwiseOperation,
         BElementwiseOperation,
@@ -819,6 +825,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         ADataType,
         BDataType,
         EDataType,
+        DsDataType,
         Block2ETileMapKSplit,
         AElementwiseOperation,
         BElementwiseOperation,
@@ -861,6 +868,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         ADataType,
         BDataType,
         EDataType,
+        DsDataType,
         Block2ETileMapKSplit,
         AElementwiseOperation,
         BElementwiseOperation,
...
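
The first file above only re-plumbs template arguments; the gridwise header below carries the substance: a new `EGlobalMemoryDataOperation` template parameter and the implemented `RunWrite`. The heart of `RunWrite` is a staged loop: per accumulator slice, threads shuffle register fragments into LDS between two barriers, then the whole block streams the re-laid-out slice to global memory. A minimal HIP/CUDA-style sketch of that double-barrier pattern, with made-up names and shapes (this is not CK code):

    // Hypothetical sketch of the shuffle loop inside RunWrite (illustrative only).
    __global__ void epilogue_sketch(float* e_grid, const float* __restrict__ acc, int num_slices)
    {
        __shared__ float lds[256]; // assumes a block of 256 threads

        for(int s = 0; s < num_slices; ++s)
        {
            __syncthreads(); // safe to overwrite LDS from the previous slice
            // each thread stages its accumulator fragment into LDS (the shuffle step)
            lds[threadIdx.x] = acc[s * blockDim.x + threadIdx.x];
            __syncthreads(); // safe to read the shuffled slice
            // block-wide, coalesced store of the re-laid-out slice to GMEM
            e_grid[(blockIdx.x * num_slices + s) * blockDim.x + threadIdx.x] = lds[threadIdx.x];
        }
    }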
@@ -44,6 +44,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
+          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
           tensor_operation::device::GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -696,6 +697,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                             block_2_etile_map);
     }

+    // TODO Need to do CShuffle already here:
     __device__ void StorePartials(void* __restrict__ p_workspace)
     {
         // M0 = grid_size
@@ -849,15 +851,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
         const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);

-        // if (threadIdx.x == 0)
-        // {
-        //     printf("w_grid_desc_m0_n0_m1_n1: [%d, %d, %d, %d]\n",
-        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I0),
-        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I1),
-        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I2),
-        //            workspace_grid_desc_m0_n0_m1_n1.GetLength(I3));
-        // }
-
         // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
             BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -981,300 +974,285 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         }
     }

-    // template <typename CThreadBufer,
-    //           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
-    //           index_t NumDTensor_,
-    //           typename DsDataType_,
-    //           typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-    //           typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-    //           typename CDEElementwiseOperation_,
-    //           typename Block2ETileMap>
-    // __device__ void RunWrite(CThreadBufer c_thread_buf,
-    //                          const EDataType* __restrict__ p_workspace,
-    //                          DsGridPointer p_ds_grid,
-    //                          EDataType* __restrict__ p_e_grid,
-    //                          void* __restrict__ p_shared,
-    //                          const index_t KBatch,
-    //                          const CDEElementwiseOperation_& cde_element_op,
-    //                          const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
-    //                              ds_grid_desc_mblock_mperblock_nblock_nperblock,
-    //                          const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
-    //                              e_grid_desc_mblock_mperblock_nblock_nperblock,
-    //                          const Block2ETileMap& block_2_etile_map)
-    // {
-    //     using DsGridDesc_M_N =
-    //         remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
-
-    //     DsGridDesc_M_N ds_grid_desc_m_n;
-
-    //     const auto ds_grid_buf = generate_tuple(
-    //         [&](auto i) {
-    //             return make_dynamic_buffer<AddressSpaceEnum::Global>(
-    //                 p_ds_grid[i],
-    //                 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
-    //         },
-    //         Number<NumDTensor_>{});
-
-    //     auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-    //         p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
-    //     static_for<0, NumDTensor, 1>{}([&](auto j) {
-    //         using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-
-    //         ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
-    //     });
-
-    //     const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
-    //     // using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-    //     //     remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-    //     //         DsGridDesc_M_N{}))>;
-
-    //     // DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-    //     //     ds_grid_desc_mblock_mperblock_nblock_nperblock;
-    //     // static_for<0, NumDTensor, 1>{}([&](auto j) {
-    //     //     ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
-    //     //         MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
-    //     // });
-
-    //     // const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
-    //     //     MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
-
-    //     // shuffle C and write out
-    //     static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
-    //                       NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
-    //                   "wrong!");
-
-    //     constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-    //     constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
-
-    //     // TODO: hacky, fix it!
-    //     constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-    //         blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-    //     // TODO: hacky, fix it!
-    //     // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-    //     constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
-    //         blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-    //     constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-    //     constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-    //     constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-    //     constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-    //     constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-    //     constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-    //     constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-    //     constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
-
-    //     constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-    //         GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-
-    //     auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-    //         static_cast<CShuffleDataType*>(p_shared),
-    //         c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
-    //     constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-    //         c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-    //         make_tuple(
-    //             make_freeze_transform(I0),
-    //             make_unmerge_transform(make_tuple(
-    //                 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
-    //                 M1,                                      // M1 = MWave
-    //                 M2,                                      // M2 * M3 * M4 = MPerXdl
-    //                 M3,
-    //                 M4)),
-    //             make_freeze_transform(I0),
-    //             make_unmerge_transform(make_tuple(
-    //                 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
-    //                 N1,                                      // N1 = NWave
-    //                 N2))),                                   // N2 = NPerXdl
-    //         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-    //         make_tuple(
-    //             Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
-
-    //     // calculate origin of thread output tensor on global memory
-    //     // blockwise GEMM c matrix starting index
-    //     const auto c_thread_mtx_on_block =
-    //         blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
-
-    //     const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
-    //     const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
-
-    //     const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
-    //         make_single_stage_tensor_adaptor(
-    //             make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
-    //             make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-    //             make_tuple(Sequence<0>{}));
-
-    //     const auto m_thread_data_on_block_idx =
-    //         m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
-    //             make_multi_index(m_thread_data_on_block));
-
-    //     const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
-    //         make_single_stage_tensor_adaptor(
-    //             make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-    //             make_tuple(Sequence<0, 1, 2>{}),
-    //             make_tuple(Sequence<0>{}));
-
-    //     const auto n_thread_data_on_block_idx =
-    //         n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
-    //             make_multi_index(n_thread_data_on_block));
-
-    //     // shuffle: threadwise copy C from VGPR to LDS
-    //     auto c_thread_copy_vgpr_to_lds =
-    //         ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-    //                                            CShuffleDataType,
-    //                                            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-    //                                            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-    //                                            ck::tensor_operation::element_wise::PassThrough,
-    //                                            Sequence<CShuffleMXdlPerWavePerShuffle,
-    //                                                     CShuffleNXdlPerWavePerShuffle,
-    //                                                     I1,
-    //                                                     I1,
-    //                                                     M2,
-    //                                                     I1,
-    //                                                     M4,
-    //                                                     I1>,
-    //                                            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-    //                                            7,
-    //                                            1,
-    //                                            InMemoryDataOperationEnum::Set,
-    //                                            1,
-    //                                            true>{
-    //             c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-    //             make_multi_index(0,
-    //                              0,
-    //                              m_thread_data_on_block_idx[I1],
-    //                              n_thread_data_on_block_idx[I1],
-    //                              m_thread_data_on_block_idx[I2],
-    //                              m_thread_data_on_block_idx[I3],
-    //                              m_thread_data_on_block_idx[I4],
-    //                              n_thread_data_on_block_idx[I2]),
-    //             ck::tensor_operation::element_wise::PassThrough{}};
-
-    //     // tuple of reference to C/Ds tensor descriptors
-    //     const auto c_ds_desc_refs = concat_tuple_of_reference(
-    //         tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-    //         generate_tie(
-    //             [&](auto i) -> const auto& // return type should be reference
-    //             { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-    //             Number<NumDTensor_>{}));
-
-    //     // tuple of reference to C/Ds tensor descriptors
-    //     const auto c_ds_buf_refs = concat_tuple_of_reference(
-    //         tie(c_shuffle_block_buf),
-    //         generate_tie(
-    //             [&](auto i) -> const auto& // return type should be reference
-    //             { return ds_grid_buf[i]; },
-    //             Number<NumDTensor_>{}));
-
-    //     // tuple of starting index of C/Ds blockwise copy
-    //     const auto idx_c_ds_block_begin = container_concat(
-    //         make_tuple(make_multi_index(0, 0, 0, 0)),
-    //         generate_tuple(
-    //             [&](auto) {
-    //                 return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
-    //             },
-    //             Number<NumDTensor_>{}));
-
-    //     // space filling curve for threadwise C in VGPR before shuffle
-    //     constexpr auto sfc_c_vgpr =
-    //         SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
-    //                           Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-    //                           Sequence<CShuffleMXdlPerWavePerShuffle,
-    //                                    CShuffleNXdlPerWavePerShuffle,
-    //                                    1,
-    //                                    1,
-    //                                    M2,
-    //                                    1,
-    //                                    M4,
-    //                                    1>>{};
-
-    //     // space filling curve for shuffled blockwise C/D/E
-    //     constexpr auto sfc_cde_block =
-    //         SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
-    //                           Sequence<0, 2, 1, 3>,
-    //                           Sequence<1,
-    //                                    CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-    //                                    1,
-    //                                    CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
-
-    //     constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
-
-    //     static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
-
-    //     // blockwise copy C/D/E between LDS and global
-    //     auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
-    //         ThisThreadBlock,
-    //         decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
-    //         Tuple<EDataType>,
-    //         decltype(c_ds_desc_refs),
-    //         decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-    //         CDEElementwiseOperation_,
-    //         Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
-    //                                                                     // Sequence support
-    //                                                                     // arbitray type
-    //         Sequence<1,
-    //                  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-    //                  1,
-    //                  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-    //         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-    //         Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-    //         Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
-    //         3,                    // index_t VectorDim,
-    //         CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-    //         sequence_merge_t<
-    //             Sequence<true>,
-    //             uniform_sequence_gen_t<NumDTensor_,
-    //                                    false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-    //         Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
-    //         {c_ds_desc_refs,
-    //          idx_c_ds_block_begin,
-    //          tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-    //          make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
-    //          cde_element_op};
-
-    //     static_for<0, num_access, 1>{}([&](auto access_id) {
-    //         // make sure it's safe to write to LDS
-    //         block_sync_lds();
-
-    //         // each thread write its data from VGPR to LDS
-    //         c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-    //                                       sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-    //                                       c_thread_buf,
-    //                                       c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-    //                                       c_shuffle_block_buf);
-
-    //         // make sure it's safe to read from LDS
-    //         block_sync_lds();
-
-    //         // each block copy its data from LDS to global
-    //         cde_block_copy_lds_and_global.Run(
-    //             c_ds_desc_refs,
-    //             c_ds_buf_refs,
-    //             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-    //             tie(e_grid_buf));
-
-    //         if constexpr(access_id < num_access - 1)
-    //         {
-    //             constexpr auto cde_lds_and_global_step =
-    //                 sfc_cde_block.GetForwardStep(access_id);
-
-    //             // move on Ds
-    //             static_for<0, NumDTensor_, 1>{}([&](auto i) {
-    //                 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
-    //                     c_ds_desc_refs, i + I1, cde_lds_and_global_step);
-    //             });
-
-    //             // move on E
-    //             cde_block_copy_lds_and_global.MoveDstSliceWindow(
-    //                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-    //                 I0,
-    //                 cde_lds_and_global_step);
-    //         }
-    //     });
-    // }
+    template <typename Block2ETileMap>
+    __device__ void RunWrite(DsGridPointer p_ds_grid,
+                             EDataType* __restrict__ p_e_grid,
+                             void* __restrict__ p_shared,
+                             const index_t M,
+                             const index_t N,
+                             const std::array<index_t, NumDTensor> StrideDs,
+                             const index_t StrideE,
+                             const CDEElementwiseOperation& cde_element_op,
+                             const Block2ETileMap& block_2_etile_map)
+    {
+        using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DsGridDesc_M_N{}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
+        });
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto ds_grid_buf = generate_tuple(
+            [&](auto i) {
+                return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_ds_grid[i],
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
+            },
+            Number<NumDTensor>{});
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+        const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
+
+        // shuffle C and write out
+        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
+                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
+                      "wrong!");
+
+        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
+
+        // divide block work by [M, N, K]
+        const auto block_work_idx = block_2_etile_map.GetBottomIndex();
+
+        // TODO: hacky, fix it!
+        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+        // TODO: hacky, fix it!
+        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
+            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
+        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
+        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+
+        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+
+        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<CShuffleDataType*>(p_shared),
+            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(
+                make_freeze_transform(I0),
+                make_unmerge_transform(make_tuple(
+                    Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
+                    M1,                                      // M1 = MWave
+                    M2,                                      // M2 * M3 * M4 = MPerXdl
+                    M3,
+                    M4)),
+                make_freeze_transform(I0),
+                make_unmerge_transform(make_tuple(
+                    Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
+                    N1,                                      // N1 = NWave
+                    N2))),                                   // N2 = NPerXdl
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
+
+        // calculate origin of thread output tensor on global memory
+        // blockwise GEMM c matrix starting index
+        const auto c_thread_mtx_on_block =
+            blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+
+        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
+        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
+
+        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                make_tuple(Sequence<0>{}));
+
+        const auto m_thread_data_on_block_idx =
+            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+                make_multi_index(m_thread_data_on_block));
+
+        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto n_thread_data_on_block_idx =
+            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+                make_multi_index(n_thread_data_on_block));
+
+        // shuffle: threadwise copy C from VGPR to LDS
+        auto c_thread_copy_vgpr_to_lds =
+            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+                                               CShuffleDataType,
+                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                                               ck::tensor_operation::element_wise::PassThrough,
+                                               Sequence<CShuffleMXdlPerWavePerShuffle,
+                                                        CShuffleNXdlPerWavePerShuffle,
+                                                        I1,
+                                                        I1,
+                                                        M2,
+                                                        I1,
+                                                        M4,
+                                                        I1>,
+                                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                               7,
+                                               1,
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>{
+                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                make_multi_index(0,
+                                 0,
+                                 m_thread_data_on_block_idx[I1],
+                                 n_thread_data_on_block_idx[I1],
+                                 m_thread_data_on_block_idx[I2],
+                                 m_thread_data_on_block_idx[I3],
+                                 m_thread_data_on_block_idx[I4],
+                                 n_thread_data_on_block_idx[I2]),
+                ck::tensor_operation::element_wise::PassThrough{}};
+
+        // tuple of reference to C/Ds tensor descriptors
+        const auto c_ds_desc_refs = concat_tuple_of_reference(
+            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+            generate_tie(
+                [&](auto i) -> const auto& // return type should be reference
+                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                Number<NumDTensor>{}));
+
+        // tuple of reference to C/Ds tensor buffers
+        const auto c_ds_buf_refs = concat_tuple_of_reference(
+            tie(c_shuffle_block_buf),
+            generate_tie(
+                [&](auto i) -> const auto& // return type should be reference
+                { return ds_grid_buf[i]; },
+                Number<NumDTensor>{}));
+
+        // tuple of starting index of C/Ds blockwise copy
+        const auto idx_c_ds_block_begin = container_concat(
+            make_tuple(make_multi_index(0, 0, 0, 0)),
+            generate_tuple(
+                [&](auto) {
+                    return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
+                },
+                Number<NumDTensor>{}));
+
+        // space filling curve for threadwise C in VGPR before shuffle
+        constexpr auto sfc_c_vgpr =
+            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                              Sequence<CShuffleMXdlPerWavePerShuffle,
+                                       CShuffleNXdlPerWavePerShuffle,
+                                       1,
+                                       1,
+                                       M2,
+                                       1,
+                                       M4,
+                                       1>>{};
+
+        // space filling curve for shuffled blockwise C/D/E
+        constexpr auto sfc_cde_block =
+            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
+                              Sequence<0, 2, 1, 3>,
+                              Sequence<1,
+                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                       1,
+                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
+
+        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
+
+        // blockwise copy C/D/E between LDS and global
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple<EDataType>,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CDEElementwiseOperation,
+            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                                        // support arbitrary type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+            3,                    // index_t VectorDim,
+            CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence<true>,
+                uniform_sequence_gen_t<NumDTensor,
+                                       false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
+            {c_ds_desc_refs,
+             idx_c_ds_block_begin,
+             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
+             cde_element_op};
+
+        static_for<0, num_access, 1>{}([&](auto access_id) {
+            // make sure it's safe to write to LDS
+            block_sync_lds();
+
+            // each thread write its data from VGPR to LDS
+            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                          c_thread_buf,
+                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                          c_shuffle_block_buf);
+
+            // make sure it's safe to read from LDS
+            block_sync_lds();
+
+            // each block copy its data from LDS to global
+            cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
+                                              c_ds_buf_refs,
+                                              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                                              tie(e_grid_buf));
+
+            if constexpr(access_id < num_access - 1)
+            {
+                constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
+
+                // move on Ds
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                });
+
+                // move on E
+                cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                    I0,
+                    cde_lds_and_global_step);
+            }
+        });
+    }
 };

 } // namespace ck
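
A note on the new `EGlobalMemoryDataOperation` parameter: the device-level struct in the first file pins `InMemoryDataOperationEnum::Set` at its gridwise instantiation, and the gridwise class forwards the enum as an entry of an integer `Sequence` to the blockwise transfer (hence the FIXME in the diff about `Sequence` only holding `index_t`). A reduced sketch of that plumbing pattern, with simplified stand-ins rather than CK's actual templates:

    // Simplified stand-ins (not CK's real templates) for how the operation
    // flag travels: device layer -> gridwise GEMM -> blockwise transfer.
    #include <cstdint>

    using index_t = std::int32_t;

    enum struct InMemoryDataOperationEnum : index_t { Set, AtomicAdd };

    template <index_t... Is> struct Sequence {};

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation>
    struct GridwiseGemmSketch
    {
        // the transfer receives the enum re-encoded as an index_t sequence entry
        using EOpSequence = Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>;
    };

    // the device layer fixes Set for the final (non-atomic) store of E:
    using GridwiseGemmForSet = GridwiseGemmSketch<InMemoryDataOperationEnum::Set>;

    int main() { GridwiseGemmForSet::EOpSequence{}; return 0; }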