Commit c1f7d9f2 authored by Adam Osewski

Accumulate partial results in workspace

parent fee53701
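For orientation, the flow this commit introduces, in heavily simplified form: each workgroup writes its partial C tile into a global workspace, a per-tile flag records how many partials have landed there, and the reducing workgroup folds the other partials into its own accumulator before the final write-out. The host-side C++ sketch below only illustrates that idea; none of the names (accumulate_partials, workspace, tile_elems, num_partials) are CK identifiers.

#include <cstddef>
#include <vector>

// Illustrative model only: `num_partials` workgroups have each written one
// partial tile of `tile_elems` values into `workspace`, back to back. The
// reducing workgroup keeps its own partial in `acc` and adds the others in.
void accumulate_partials(std::vector<float>& acc,
                         const std::vector<float>& workspace,
                         std::size_t tile_elems,
                         std::size_t num_partials)
{
    // Partial 0 is assumed to be the reducer's own tile, already held in
    // `acc`, mirroring the kernel loop that starts at i_t = 1.
    for(std::size_t p = 1; p < num_partials; ++p)
    {
        for(std::size_t i = 0; i < tile_elems; ++i)
        {
            acc[i] += workspace[p * tile_elems + i];
        }
    }
}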
@@ -177,9 +177,11 @@ __global__ void
// Accumulate partial results. We can have different # of workgroups to reduce, thus we
// read actual flag value.
-[[maybe_unused]] const index_t flag_v = __builtin_amdgcn_readfirstlane(
+const index_t flag_v = __builtin_amdgcn_readfirstlane(
    work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
+gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
// TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
// Signal waiting blocks that they can start using their workspace.
...
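In the hunk above, flag_v loses its [[maybe_unused]] attribute because the value is now consumed: it tells AccumulatePartials how many partial tiles actually have to be reduced. A minimal device-side sketch of the broadcast pattern, assuming a hypothetical tile_flags counter array (only __builtin_amdgcn_readfirstlane is a real compiler builtin here; GetFlagValue is the CK accessor actually used above):

#include <hip/hip_runtime.h>

// Hedged sketch, not CK code: fetch a per-tile completion counter and make it
// wave-uniform, the way flag_v is used above.
__device__ int read_reduce_count(const int* tile_flags, int tile_idx)
{
    const int flag_v = tile_flags[tile_idx];
    // readfirstlane returns the value held by the first active lane, placing
    // the result in a scalar register shared by the whole wavefront.
    return __builtin_amdgcn_readfirstlane(flag_v);
}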
@@ -4,19 +4,20 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
@@ -807,7 +808,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
1, // DstScalarStrideInVector
true>{// DstResetCoordinateAfterRun
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-make_multi_index(static_cast<index_t>(blockIdx.x),
+make_multi_index(static_cast<index_t>(blockIdx.x) * MXdlPerWave,
    n_thread_data_on_block_idx[I0],
    m_thread_data_on_block_idx[I1],
    n_thread_data_on_block_idx[I1],
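The `* MXdlPerWave` factor added above matches the workspace layout built in AccumulatePartials below: the leading workspace dimension merges the writing block's grid index with its MRepeat slots, and the merge transform suggests M0 there equals MXdlPerWave. Under that assumption, block b owns merged indices b * MXdlPerWave through (b + 1) * MXdlPerWave - 1; for example, with MXdlPerWave = 4, block 2 writes its four repeats at merged indices 8..11.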
@@ -824,6 +825,162 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
w_grid_buf);
}
__device__ void AccumulatePartials(void* __restrict__ p_workspace, index_t reduce_count)
{
auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// using CThreadBufferT = ck::remove_reference_t<decltype(c_thread_buf)>;
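// acc_buf is a VGPR staging buffer with the same element count as the C
// thread tile; each iteration of the reduction loop below clears it, loads
// one partial tile from the workspace into it, and adds it into c_thread_buf.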
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
true>
acc_buf{};
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const auto workspace_grid_desc_m0_n0_m1_n1 =
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
// if (threadIdx.x == 0)
// {
// printf("w_grid_desc_m0_n0_m1_n1: [%d, %d, %d, %d]\n",
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I0),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I1),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I2),
// workspace_grid_desc_m0_n0_m1_n1.GetLength(I3));
// }
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// M0 = grid_size -> MRepeats
// N0 = 1 -> NRepeats
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
make_pass_through_transform(w_grid_n0),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
make_pass_through_transform(M1), // MWave
make_pass_through_transform(N1), // NWave
make_pass_through_transform(M2), // mfma_instr.num_groups_per_blk
make_pass_through_transform(M3), // mfma_instr.num_input_blks
make_pass_through_transform(M4), // mfma_instr.group_size
make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
make_tuple(Sequence<0, 2>{},
Sequence<1, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{},
Sequence<8>{},
Sequence<9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
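// Net layout after the two transforms above: the workspace is addressed as
// [grid_size * M0, N0, M1, N1, M2, M3, M4, N2], with the merged leading index
// equal to block_idx * M0 + m_repeat. Block b's partials therefore occupy
// merged indices b * M0 .. (b + 1) * M0 - 1, which is consistent with the
// workspace store offset blockIdx.x * MXdlPerWave and with the loads below
// starting at (blockIdx.x + 1) * MXdlPerWave.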
const auto c_thread_mtx_on_block =
blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
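// The adaptors below decompose this thread's flat intra-block M/N origin into
// the (M0, M1, M2, M3, M4) and (N0, N1, N2) sub-indices needed to position
// the per-thread slice inside the workspace descriptor.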
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2), // SrcDesc,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), // DstDesc,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // DimAccessOrder,
7, // SrcVectorDim,
1, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// We do not need to read this workgroup's partial results since they're
// already in c_thread_buf.
make_multi_index((static_cast<index_t>(blockIdx.x) + 1) * MXdlPerWave,
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2])};
using Accumulation =
ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
// We do not need to read this workgroup's partial results since they're
// already in c_thread_buf.
for(int i_t = 1; i_t < reduce_count; ++i_t)
{
acc_buf.Clear();
acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
w_grid_buf,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
acc_buf);
static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
[&](auto i_vec) { Accumulation::Calculate(c_thread_buf(i_vec), acc_buf[i_vec]); });
}
}
// template <typename CThreadBufer,
//           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
//           index_t NumDTensor_,
...
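For intuition, one iteration of the reduction loop in AccumulatePartials (acc_buf.Clear, acc_load.Run, then Accumulation::Calculate with reduce::Add and PropagateNan = false) amounts to staging one partial tile and adding it element-wise into the thread's accumulator. A plain C++ stand-in, with N playing the role of the thread descriptor's element count (not CK code):

#include <array>
#include <cstddef>

// Hedged stand-in for one loop iteration: clear the staging buffer, load one
// partial tile into it, then add it into the running accumulator.
template <std::size_t N>
void fold_partial(std::array<float, N>& c_thread_buf,
                  std::array<float, N>& acc_buf,
                  const float* partial_tile)
{
    acc_buf.fill(0.0f); // acc_buf.Clear()
    for(std::size_t i = 0; i < N; ++i)
    {
        acc_buf[i] = partial_tile[i]; // acc_load.Run(...)
    }
    for(std::size_t i = 0; i < N; ++i)
    {
        c_thread_buf[i] += acc_buf[i]; // Accumulation::Calculate, reduce::Add
    }
}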