"tests/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "8efe23f15087222540ec076ed00785544442c02f"
Commit 98def248 authored by Adam Osewski

Rework RunWrite.

parent bbd26e10
@@ -157,10 +157,12 @@ __global__ void
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
GridwiseGemm::StorePartials(p_workspace, results_buffer);
}
// With CShuffle at store-partials, all workgroups have to store
// their partials to the workspace in gmem.
// TODO: The reduction workgroup doesn't have to store its own results to GMEM!
// It would be enough to keep them in registers and, during AccumulatePartials,
// do CShuffle in flight while loading the partial products of other peer workgroups.
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
work_scheduler.FlagFinished();
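With this rework every workgroup, including the one that later performs the reduction, runs the CShuffle through LDS (p_shared) and then writes its partial tile to the GMEM workspace before flagging completion. Below is a minimal sketch of that store step, assuming a flat per-workgroup slot layout; workspace, slot_offset and TileSize are illustrative names rather than the kernel's real interface.

```cpp
// Hypothetical sketch (not the kernel's real interface): after the CShuffle
// through LDS, each workgroup copies its partial tile into its own slot of
// the GMEM workspace before flagging completion.
#include <cstddef>

template <typename CShuffleDataT, std::size_t TileSize>
void store_partials(CShuffleDataT* workspace,             // acc workspace in GMEM
                    std::size_t slot_offset,               // this workgroup's slot (assumed layout)
                    const CShuffleDataT (&tile)[TileSize]) // CShuffle'd partial results
{
    for(std::size_t i = 0; i < TileSize; ++i)
        workspace[slot_offset + i] = tile[i];
}
```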
@@ -171,10 +173,20 @@ __global__ void
index_t neighbour_count =
work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
constexpr auto workspace_thread_desc_m0m1_n0n1n2 =
GridwiseGemm::MakeReductionThreadDesc_M0M1_N0N1N2();
StaticBuffer<AddressSpaceEnum::Vgpr,
typename GridwiseGemm::CShuffleDataT,
workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(),
true>
acc_buff{};
acc_buff.Clear();
// Accumulate only when there are at least two workgroups processing split-K data-tiles
// across the same MN-output tile.
if(neighbour_count > 0)
GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1);
GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(neighbour_count);
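The new acc_buff is a small VGPR-resident buffer that the reduction workgroup clears and then fills by summing the partials of neighbour_count + 1 peer workgroups. A loose, host-style sketch of that step follows; partial_stride and TileSize are made-up names standing in for the layout the tile map actually provides.

```cpp
// Hypothetical sketch of the reduction step: acc_buff is modelled as a small
// value-initialized array (the Clear() analogue) into which the partials of
// neighbour_count + 1 workgroups are summed.
#include <array>
#include <cstddef>

template <typename CShuffleDataT, std::size_t TileSize>
std::array<CShuffleDataT, TileSize>
accumulate_partials(const CShuffleDataT* workspace, // per-MN-tile region of the workspace (assumed)
                    std::size_t partial_stride,     // elements between peer partials (assumed)
                    int num_partials)               // neighbour_count + 1
{
    std::array<CShuffleDataT, TileSize> acc{};      // cleared accumulator (acc_buff analogue)
    for(int p = 0; p < num_partials; ++p)
        for(std::size_t i = 0; i < TileSize; ++i)
            acc[i] += workspace[static_cast<std::size_t>(p) * partial_stride + i];
    return acc;                                     // later consumed by RunWrite
}
```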
@@ -195,17 +207,17 @@ __global__ void
GridwiseGemm::template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
acc_buff,
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map,
results_buffer);
b2c_tile_map);
}
else if(work_scheduler.HasTile())
{
// TODO Move this just before StorePartials!
work_scheduler.WaitForReduction();
}
} while(work_scheduler.HasTile());
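The signature change above is the heart of the rework: RunWrite now consumes the accumulated tile (acc_buff) rather than the raw per-thread results_buffer. A rough sketch of such a post-reduction write, simplified to one D tensor and flat indexing purely for illustration:

```cpp
// Loose sketch only: combine the accumulated tile with a D tensor through the
// CDE element-wise op and write the result to E. The names, the single-D
// simplification, and the flat indexing are assumptions for illustration.
#include <cstddef>

template <typename AccT, typename DT, typename ET, typename CDEOp, std::size_t TileSize>
void run_write(ET* e_tile,                  // output tile in E (GMEM)
               const DT* d_tile,            // one D tensor, for simplicity
               const AccT (&acc)[TileSize], // accumulated partials (acc_buff)
               CDEOp cde_op)                // e.g. cde_op(e, acc, d)
{
    for(std::size_t i = 0; i < TileSize; ++i)
        cde_op(e_tile[i], acc[i], d_tile[i]);
}
```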
@@ -757,7 +769,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
<< ", grid_size: " << grid_size << ", flag_count: " << flag_count
<< ", p_flags: " << p_flags << ", workspace_ptr: " << dev_gemm_workspace
<< ", acc_workspace_size_bytes: " << acc_workspace_size_bytes
<< std::endl;
<< ", kbatch: " << arg.K_BATCH << std::endl;
}
auto preprocess = [&]() {
@@ -995,7 +1007,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// the amount of workspace bytes needed; it may be less due to the number of available CUs in the
// stream used to launch the kernel.
size_t size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(AccDataType), grid_size) +
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
flag_count * sizeof(uint32_t);
return size_bytes;
}
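Because the partials are now stored as CShuffleDataType, the workspace is sized from sizeof(CShuffleDataType) instead of sizeof(AccDataType). A rough host-side illustration of that sizing, assuming one partial tile per launched workgroup plus one uint32_t flag per flag_count entry (the exact layout is internal to Block2ETileMapKSplit and GetAccWorkspaceSize):

```cpp
// Rough illustration under the assumptions stated above; elems_per_tile is a
// hypothetical stand-in for the per-workgroup output-tile element count.
#include <cstddef>
#include <cstdint>

std::size_t estimate_workspace_bytes(std::size_t elems_per_tile,
                                     std::size_t grid_size,
                                     std::size_t flag_count,
                                     std::size_t cshuffle_elem_size) // sizeof(CShuffleDataType)
{
    return grid_size * elems_per_tile * cshuffle_elem_size // partial-result storage
         + flag_count * sizeof(uint32_t);                  // per-tile synchronization flags
}
```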