"example/vscode:/vscode.git/clone" did not exist on "95a83c6ebfaa8d2a07aae9ebf4922879ecfcd630"
Commit 98def248 authored by Adam Osewski
Browse files

Rework RunWrite.

parent bbd26e10
...@@ -157,10 +157,12 @@ __global__ void ...@@ -157,10 +157,12 @@ __global__ void
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx()); } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// if (changed group_id || next [M,N] tile) // if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock()) // With cshuffle at store partials all workgroups have to store
{ // their partials to workspace gmem.
GridwiseGemm::StorePartials(p_workspace, results_buffer); // TODO: The reduction workgroup doesn't have to store its own results to GMEM!
} // Would be enough to keep it in registers and during AccumulatePartials
// do CShuffle in flight with loading partial products of other peer workgroups.
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
work_scheduler.FlagFinished(); work_scheduler.FlagFinished();
...@@ -171,10 +173,20 @@ __global__ void ...@@ -171,10 +173,20 @@ __global__ void
index_t neighbour_count = index_t neighbour_count =
work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx()); work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
constexpr auto workspace_thread_desc_m0m1_n0n1n2 =
GridwiseGemm::MakeReductionThreadDesc_M0M1_N0N1N2();
StaticBuffer<AddressSpaceEnum::Vgpr,
typename GridwiseGemm::CShuffleDataT,
workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(),
true>
acc_buff{};
acc_buff.Clear();
// Accumulate only when there are at least two workgroups processing splitk data-tiles // Accumulate only when there are at least two workgroups processing splitk data-tiles
// across same MN-output tile. // across same MN-output tile.
if(neighbour_count > 0) if(neighbour_count > 0)
GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1); GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);
// Signal waiting blocks that they can start using their workspace. // Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(neighbour_count); work_scheduler.Reset(neighbour_count);
...@@ -195,17 +207,17 @@ __global__ void ...@@ -195,17 +207,17 @@ __global__ void
GridwiseGemm::template RunWrite(p_ds_grid, GridwiseGemm::template RunWrite(p_ds_grid,
p_e_grid, p_e_grid,
static_cast<void*>(p_shared), acc_buff,
M, M,
N, N,
stride_ds, stride_ds,
stride_e, stride_e,
cde_element_op, cde_element_op,
b2c_tile_map, b2c_tile_map);
results_buffer);
} }
else if(work_scheduler.HasTile()) else if(work_scheduler.HasTile())
{ {
// TODO Move this just before StorePartials!
work_scheduler.WaitForReduction(); work_scheduler.WaitForReduction();
} }
} while(work_scheduler.HasTile()); } while(work_scheduler.HasTile());
...@@ -757,7 +769,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -757,7 +769,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
<< ", grid_size: " << grid_size << ", flag_count: " << flag_count << ", grid_size: " << grid_size << ", flag_count: " << flag_count
<< ", p_flags: " << p_flags << ", workspace_ptr: " << dev_gemm_workspace << ", p_flags: " << p_flags << ", workspace_ptr: " << dev_gemm_workspace
<< ", acc_workspace_size_bytes: " << acc_workspace_size_bytes << ", acc_workspace_size_bytes: " << acc_workspace_size_bytes
<< std::endl; << ", kbatch: " << arg.K_BATCH << std::endl;
} }
auto preprocess = [&]() { auto preprocess = [&]() {
...@@ -995,7 +1007,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -995,7 +1007,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// the amount of workspace bytes needed, may be less due to the number of available CUs in // the amount of workspace bytes needed, may be less due to the number of available CUs in
// stream used to launch kernel. // stream used to launch kernel.
size_t size_bytes = size_t size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(AccDataType), grid_size) + Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
flag_count * sizeof(uint32_t); flag_count * sizeof(uint32_t);
return size_bytes; return size_bytes;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment