Commit 821f6130 authored by Adam Osewski

Store partials from VGPR to GMEM

parent 51ae4aa2
@@ -160,24 +160,7 @@ __global__ void
if(!b2c_tile_map.IsFirstKSplitBlock())
{
// Store partial results to auxiliary workspace.
// make results buffer tensor descriptor (registers).
// make workspace gmem tensor descriptor
// create ThreadGroupTransform and run copy
// if (threadIdx.x == 0)
// {
// // using CThreadBuffer = decltype(results_buffer);
// // constexpr index_t n_scalars = CThreadBuffer::s_per_buf.value;
// constexpr index_t n_scalars = 4;
// static_for<0, n_scalars, 1>{}([&](auto i) {
// printf("[kernel] bid: %d; c_thread_buff[%d]: %f\n",
// static_cast<index_t>(blockIdx.x),
// i.value,
// static_cast<float>(results_buffer[i]));
// });
// }
gridwise_gemm.StorePartials(p_workspace);
}
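
For context: each work-group parks its MPerBlock x NPerBlock accumulator tile in its own slot of the auxiliary workspace, so the host has to provision grid_size such slots of AccDataType. Below is a minimal sizing sketch, assuming a hypothetical helper GetWorkspaceSizeBytes that is not part of this commit.

#include <cstddef>

// Illustrative only: one MPerBlock x NPerBlock partial tile of AccDataType per
// work-group, matching the (grid_size, 1, MPerBlock, NPerBlock) workspace
// descriptor introduced further down in this commit.
std::size_t GetWorkspaceSizeBytes(std::size_t grid_size,
                                  std::size_t m_per_block,
                                  std::size_t n_per_block,
                                  std::size_t acc_bytes) // e.g. sizeof(float) for fp32 accumulation
{
    return grid_size * m_per_block * n_per_block * acc_bytes;
}
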
const index_t output_tile_idx =
@@ -197,21 +180,6 @@ __global__ void
[[maybe_unused]] const index_t flag_v = __builtin_amdgcn_readfirstlane(
work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
// if(threadIdx.x == 0)
// {
// // using CThreadBuffer = decltype(results_buffer);
// // constexpr index_t n_scalars = CThreadBuffer::s_per_buf.value;
// constexpr index_t n_scalars = 4;
// static_for<0, n_scalars, 1>{}([&](auto i) {
// printf("[kernel] bid: %d; c_thread_buff[%d]: %f\n",
// static_cast<index_t>(blockIdx.x),
// i.value,
// static_cast<float>(results_buffer[i]));
// });
// }
// TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
// Signal waiting blocks that they can start using their workspace.
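
A rough sketch of the reduction the TODO above describes: fold the partial tiles produced by the other k-split blocks into this block's results before the epilogue. All names below (reduce_partials, partials, results) are illustrative placeholders rather than the kernel's actual buffers, and the contiguous per-k-batch layout is an assumption for the sketch.

#include <cstddef>

// Illustrative only: fold k_batch partial tiles (tile_elems values each, laid out
// back to back for this sketch) into one result tile. In the kernel the target of
// the accumulation would be the per-thread register buffer, not plain memory.
void reduce_partials(const float* partials, // [k_batch * tile_elems]
                     float* results,        // [tile_elems], already holds this block's own partial
                     std::size_t k_batch,
                     std::size_t tile_elems)
{
    for(std::size_t b = 1; b < k_batch; ++b)        // partial 0 is assumed to live in `results`
    {
        for(std::size_t i = 0; i < tile_elems; ++i) // element-wise accumulation
            results[i] += partials[b * tile_elems + i];
    }
}
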
......
@@ -519,10 +519,32 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
Number<NumDTensor>{});
}
__host__ __device__ static auto
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
{
const auto w_desc_grid_i1_mperb_nperb = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, NPerBlock, I1.value));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
}
}();
return w_desc_grid_i1_mperb_nperb;
}
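
The strides above give every work-group a contiguous MPerBlock * NPerBlock slab, with the tile itself stored row- or column-major to match ELayout. A small standalone check of that offset mapping follows (plain host C++, illustrative only; 128 and 64 are made-up tile sizes, not values from this commit).

#include <cassert>

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 64; // made-up example tile sizes

    // Row-major E: strides (M*N, M*N, N, 1) => offset of element (g, 0, m, n)
    auto offset_rowmajor = [&](int g, int m, int n) {
        return g * MPerBlock * NPerBlock + m * NPerBlock + n;
    };
    // Column-major E: strides (M*N, M*N, 1, M) => offset of element (g, 0, m, n)
    auto offset_colmajor = [&](int g, int m, int n) {
        return g * MPerBlock * NPerBlock + m + n * MPerBlock;
    };

    // Work-group 3, element (m=5, n=7) of its tile:
    assert(offset_rowmajor(3, 5, 7) == 3 * 128 * 64 + 5 * 64 + 7);
    assert(offset_colmajor(3, 5, 7) == 3 * 128 * 64 + 5 + 7 * 128);
    return 0;
}
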
// TODO: we should refactor out all those common Make... descriptors to something like
// gridwise_gemm_utils.hpp
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
__device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }
__device__ __host__ constexpr auto& GetCThreadBuffer()
{
@@ -673,6 +695,135 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
block_2_etile_map);
}
__device__ void StorePartials(void* __restrict__ p_workspace)
{
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const auto workspace_grid_desc_m0_n0_m1_n1 =
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1.GetElementSpaceSize());
const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// M0 = grid_size -> MRepeats
// N0 = 1 -> NRepeats
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
make_pass_through_transform(w_grid_n0),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
make_pass_through_transform(M1), // MWave
make_pass_through_transform(N1), // NWave
make_pass_through_transform(M2), // mfma_instr.num_groups_per_blk
make_pass_through_transform(M3), // mfma_instr.num_input_blks
make_pass_through_transform(M4), // mfma_instr.group_size
make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
make_tuple(Sequence<0, 2>{},
Sequence<1, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{},
Sequence<8>{},
Sequence<9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
const auto c_thread_mtx_on_block =
blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
auto c_thread_copy_vgpr_to_gmem = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
AccDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // DimAccessOrder
7, // DstVectorDim,
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{// DstResetCoordinateAfterRun
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(static_cast<index_t>(blockIdx.x) * M0 + m_thread_data_on_block_idx[I0],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
w_grid_buf);
}
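
The two single-stage adaptors used above simply split a thread's linear block-level m (or n) coordinate into the digits of the merged (M0, M1, M2, M3, M4) (or (N0, N1, N2)) ordering. Below is a standalone sketch of that mixed-radix decomposition, with made-up digit lengths rather than the kernel's real M0..M4 values.

#include <array>
#include <cassert>
#include <cstddef>

// Illustrative only: decompose a linear index into mixed-radix digits
// (slowest-varying digit first), which is what CalculateBottomIndex does
// for a merge transform of (M0, M1, M2, M3, M4).
template <std::size_t N>
std::array<int, N> decompose(int linear, const std::array<int, N>& lengths)
{
    std::array<int, N> digits{};
    for(std::size_t i = N; i-- > 0;) // peel off the fastest-varying digit first
    {
        digits[i] = linear % lengths[i];
        linear /= lengths[i];
    }
    return digits;
}

int main()
{
    // Made-up lengths standing in for (M0, M1, M2, M3, M4).
    const std::array<int, 5> lengths{2, 2, 4, 2, 4};
    const auto digits = decompose(37, lengths);
    // 37 == 0*64 + 1*32 + 0*8 + 1*4 + 1*1 (strides 64, 32, 8, 4, 1)
    assert(digits[0] == 0 && digits[1] == 1 && digits[2] == 0 && digits[3] == 1 && digits[4] == 1);
    return 0;
}
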
// template <typename CThreadBuffer,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// index_t NumDTensor_,
......