Commit 1c5b049d authored by Adam Osewski

Add Cshuffle and results write to GMEM.
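The grouped GEMM split-K kernel now takes the DsDataType type list as an additional template parameter, reads each group's E pointer and D/E strides from the GEMM descriptor, casts the type-erased D-tensor pointers to their element types, and calls GridwiseGemm::RunWrite to apply the CDE elementwise operation and store the shuffled output tile to GMEM. The device op accordingly passes InMemoryDataOperationEnum::Set to the gridwise GEMM and threads DsDataType through every kernel instantiation.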

parent 92eb966d
@@ -53,6 +53,7 @@ template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename DsDataType,
typename Block2ETileMapKSplit,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -124,12 +125,10 @@ __global__ void
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
// const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
// const auto StrideC = gemm_desc_ptr[group_id].StrideC;
auto gridwise_gemm = GridwiseGemm();
auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
@@ -159,7 +158,6 @@ __global__ void
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
// Store partial results to auxiliary workspace.
gridwise_gemm.StorePartials(p_workspace);
}
@@ -182,27 +180,33 @@ __global__ void
gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
// TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
// TODO: do fusion, cshuffle and store results to GMEM
// gridwise_gemm.RunWrite(results_buffer,
// p_c_grid,
// M,
// N,
// K,
// StrideA,
// StrideB,
// StrideC,
// MPadded,
// NPadded,
// KPadded,
// K0,
// k_batch,
// static_cast<void*>(p_shared),
// b2c_tile_map);
const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
const auto stride_e = gemm_desc_ptr[group_id].StrideE;
const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
constexpr auto NumDTensor = DsDataType::Size();
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
gridwise_gemm.template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map);
}
else
{
@@ -303,6 +307,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
@@ -687,6 +692,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
@@ -819,6 +825,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
@@ -861,6 +868,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
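For orientation, the control flow in the hunks above is: every workgroup computes a partial [M,N] tile for its K slice; workgroups that are not the first K-split block for a tile spill their partials to a GMEM workspace (StorePartials), while the accumulating block waits on a flag, reduces the workspace into its register buffer (AccumulatePartials), releases the workspace (work_scheduler.Reset), and then performs the newly added CShuffle write. The snippet below is a toy, single-threaded C++ model of that flow only; the scheduler, flags, and all CK types are elided, and every name and size in it is made up for illustration.

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int k_batch  = 4; // number of K-split blocks per output tile (made up)
    constexpr int tile_len = 8; // flattened [M,N] tile size (made up)

    std::vector<float> workspace((k_batch - 1) * tile_len, 0.f); // GMEM workspace stand-in
    std::vector<float> e_tile(tile_len, 0.f); // accumulating block's register buffer

    for(int kb = 0; kb < k_batch; ++kb)
    {
        // Stand-in for the partial GEMM results of K-split block kb.
        std::vector<float> partial(tile_len, float(kb + 1));

        if(kb != 0) // !IsFirstKSplitBlock(): spill partials to the workspace
            std::copy(partial.begin(), partial.end(),
                      workspace.begin() + (kb - 1) * tile_len);
        else        // the first K-split block keeps its partial in registers
            e_tile = partial;
    }

    // AccumulatePartials(): reduce the spilled partials into the register tile.
    for(int kb = 1; kb < k_batch; ++kb)
        for(int i = 0; i < tile_len; ++i)
            e_tile[i] += workspace[(kb - 1) * tile_len + i];

    // RunWrite() would now apply the CDE elementwise op, CShuffle, and the store
    // to E; here we just print the reduced tile (1 + 2 + 3 + 4 = 10 per element).
    for(float v : e_tile)
        std::printf("%g ", v);
    std::printf("\n");
}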
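The static_for loop added in the kernel turns an array of type-erased D-tensor pointers into a tuple of correctly typed pointers, one per entry of the DsDataType type list; the tuple must be heterogeneous because each D tensor may have a different element type. Below is a standalone C++17 sketch of the same pattern built on std::tuple and std::index_sequence instead of CK's static_for and Tuple; DsDataTypeList, make_ds_grid_pointer, and the float element types are hypothetical stand-ins.

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Hypothetical stand-in for the commit's DsDataType type list.
using DsDataTypeList = std::tuple<float, float>;

template <typename DsTuple, std::size_t... Is>
auto make_ds_grid_pointer_impl(const std::array<const void*, sizeof...(Is)>& raw,
                               std::index_sequence<Is...>)
{
    // For each index i, cast raw[i] to const Di*, where Di is the i-th element
    // of the type list; this mirrors the kernel's static_for loop.
    return std::make_tuple(
        static_cast<const std::tuple_element_t<Is, DsTuple>*>(raw[Is])...);
}

template <typename DsTuple>
auto make_ds_grid_pointer(
    const std::array<const void*, std::tuple_size_v<DsTuple>>& raw)
{
    return make_ds_grid_pointer_impl<DsTuple>(
        raw, std::make_index_sequence<std::tuple_size_v<DsTuple>>{});
}

int main()
{
    float d0[4] = {}, d1[4] = {};
    std::array<const void*, 2> raw{d0, d1};
    // p_ds_grid has type std::tuple<const float*, const float*>.
    auto p_ds_grid = make_ds_grid_pointer<DsDataTypeList>(raw);
    (void)p_ds_grid;
}

The pack expansion plays the role of the kernel's compile-time loop: the cast target type changes at every index, which is exactly what a runtime loop over void* pointers could not express.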