Commit 1c5b049d authored by Adam Osewski's avatar Adam Osewski
Browse files

Add Cshuffle and results write to GMEM.

parent 92eb966d
...@@ -53,6 +53,7 @@ template <typename GridwiseGemm, ...@@ -53,6 +53,7 @@ template <typename GridwiseGemm,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename DsDataType,
typename Block2ETileMapKSplit, typename Block2ETileMapKSplit,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -124,12 +125,10 @@ __global__ void ...@@ -124,12 +125,10 @@ __global__ void
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid); const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid); const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
// const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
const auto K = gemm_desc_ptr[group_id].K; const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB; const auto StrideB = gemm_desc_ptr[group_id].StrideB;
// const auto StrideC = gemm_desc_ptr[group_id].StrideC;
auto gridwise_gemm = GridwiseGemm(); auto gridwise_gemm = GridwiseGemm();
auto& results_buffer = gridwise_gemm.GetCThreadBuffer(); auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
...@@ -159,7 +158,6 @@ __global__ void ...@@ -159,7 +158,6 @@ __global__ void
// if (changed group_id || next [M,N] tile) // if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock()) if(!b2c_tile_map.IsFirstKSplitBlock())
{ {
// Store partial results to auxilliary workspace.
gridwise_gemm.StorePartials(p_workspace); gridwise_gemm.StorePartials(p_workspace);
} }
...@@ -182,27 +180,33 @@ __global__ void ...@@ -182,27 +180,33 @@ __global__ void
gridwise_gemm.AccumulatePartials(p_workspace, flag_v); gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
// TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
// Signal waiting blocks that they can start use their workspace. // Signal waiting blocks that they can start use their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset); work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
// TODO do fusion, cshuffle and store results to GMEM const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
// gridwise_gemm.RunWrite(results_buffer, const auto stride_e = gemm_desc_ptr[group_id].StrideE;
// p_c_grid, const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
// M,
// N, constexpr auto NumDTensor = DsDataType::Size();
// K, using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
// StrideA,
// StrideB, DsGridPointer p_ds_grid;
// StrideC,
// MPadded, static_for<0, NumDTensor, 1>{}([&](auto i) {
// NPadded, using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// KPadded, // D pointer
// K0, p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
// k_batch, });
// static_cast<void*>(p_shared),
// b2c_tile_map); gridwise_gemm.template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map);
} }
else else
{ {
...@@ -303,6 +307,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -303,6 +307,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec, GemmSpec,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -687,6 +692,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -687,6 +692,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType, ADataType,
BDataType, BDataType,
EDataType, EDataType,
DsDataType,
Block2ETileMapKSplit, Block2ETileMapKSplit,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -819,6 +825,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -819,6 +825,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType, ADataType,
BDataType, BDataType,
EDataType, EDataType,
DsDataType,
Block2ETileMapKSplit, Block2ETileMapKSplit,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -861,6 +868,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -861,6 +868,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType, ADataType,
BDataType, BDataType,
EDataType, EDataType,
DsDataType,
Block2ETileMapKSplit, Block2ETileMapKSplit,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment