Commit 5b2eb1ab authored by Adam Osewski

Single flag per workgroup synchronization scheme.

parent f8ca9048
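For orientation, here is a minimal, self-contained sketch of the synchronization scheme this commit switches to, assuming a zero-initialized array of 32-bit flags with one slot per workgroup; the helper names below are illustrative only and are not the CK API. Each workgroup bumps its own slot once its partial tile sits in the workspace; the workgroup that owns the first K-split of an [M,N] tile polls the slots of the `neighbour_count` blocks that follow it, accumulates their partials, and finally clears every slot it consumed so those peers may reuse their workspace.

```cpp
// Illustrative sketch only, not the CK implementation. Flags live in a
// zero-initialized array in global memory, one 32-bit slot per workgroup,
// indexed by blockIdx.x.

// Producer side: bump this workgroup's own slot once its partials are stored.
__device__ void flag_finished(unsigned int* flags)
{
    if(threadIdx.x == 0)
    {
        atomicAdd(&flags[blockIdx.x], 1u);
    }
    __syncthreads();
}

// Reducer side: poll the slots of the neighbour_count blocks that hold the
// remaining K-splits of this [M,N] tile until each of them has flagged once.
__device__ void wait_for_neighbours(unsigned int* flags, int neighbour_count)
{
    if(threadIdx.x == 0)
    {
        unsigned int flag_sum = 0;
        do
        {
            flag_sum = 0;
            for(int i = 1; i <= neighbour_count; ++i)
            {
                // atomicAdd with 0 acts as an atomic read of the peer's slot.
                flag_sum += atomicAdd(&flags[blockIdx.x + i], 0u);
            }
        } while(flag_sum != static_cast<unsigned int>(neighbour_count));
    }
    __syncthreads();
}

// Reducer side, after accumulation: clear every consumed slot, including this
// workgroup's own (i == 0), so the peers may reuse their workspace.
__device__ void reset_flags(unsigned int* flags, int neighbour_count)
{
    if(threadIdx.x == 0)
    {
        for(int i = 0; i <= neighbour_count; ++i)
        {
            atomicExch(&flags[blockIdx.x + i], 0u);
        }
    }
    __syncthreads();
}

// Peer side: before starting the next tile, wait until the reducer has cleared
// this workgroup's slot.
__device__ void wait_for_reduction(unsigned int* flags)
{
    if(threadIdx.x == 0)
    {
        while(atomicAdd(&flags[blockIdx.x], 0u) != 0u) {}
    }
    __syncthreads();
}
```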
@@ -156,32 +156,28 @@ __global__ void
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

-        const index_t output_tile_idx =
-            __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
-        const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
-
         // if (changed group_id || next [M,N] tile)
         if(!b2c_tile_map.IsFirstKSplitBlock())
         {
             GridwiseGemm::StorePartials(p_workspace, results_buffer);
         }

-        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
+        work_scheduler.FlagFinished();

         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
             // Wait untill all other blocks for this [M,N] tile store their results.
-            index_t neighbour_count = work_scheduler.WaitForNeighbours(
-                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
+            index_t neighbour_count =
+                work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());

             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
-            if(neighbour_count > 1)
-                GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count);
+            if(neighbour_count > 0)
+                GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1);

             // Signal waiting blocks that they can start use their workspace.
-            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.Reset(neighbour_count);

             const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
             const auto stride_e = gemm_desc_ptr[group_id].StrideE;
@@ -210,7 +206,7 @@ __global__ void
         }
         else if(work_scheduler.HasTile())
         {
-            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.WaitForReduction();
         }
     } while(work_scheduler.HasTile());
 #else
@@ -752,7 +748,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
                             Block2ETileMapKSplit::GetAccWorkspaceSize(
                                 sizeof(typename GridwiseGemm::AccType), grid_size);
-            std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+            std::size_t flag_count = grid_size;

             if(stream_config.log_level_ > 0)
             {
@@ -993,7 +989,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
        {
            grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
        }
-       int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+       int flag_count = grid_size;

        // This would be the maximum needed workspace size. Since actual grid size, which determines
        // the amount of workspace bytes needed, may be less due to the number of available CUs in
...
@@ -32,8 +32,7 @@ enum struct WorkSchedulingPolicy
 class StridedReductionTileLoop
 {
     public:
-    __device__ StridedReductionTileLoop(index_t tile_count,
-                                        volatile uint32_t* const __restrict__ p_flags)
+    __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
         : tile_count_{tile_count},
           tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
@@ -54,62 +53,29 @@ class StridedReductionTileLoop
         return HasTile();
     }

-    __device__ index_t GetFlagCount(index_t k_tiles) const
-    {
-        // This is the number of MN-output tiles which we cover with workgroups.
-        // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
-        return (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
-    }
+    __device__ index_t GetFlagCount() const { return get_grid_size(); }

     ///
-    /// @brief      Calculate this workgroup flag index.
-    ///
-    /// @note       Note this scheduler intentionaly does not have flag index as its member, since
-    ///             current workgroup may process tiles across different MN-output tiles or
-    ///             acorss different GEMMs (grouped gemm).
-    ///
-    /// @param[in]  k_tiles                 The number of data tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) linear tile index (of current GEMM).
-    /// @param[in]  output_tile_idx_offset  The accumulated offset of output tiles from previous
-    ///                                     GEMMs.
+    /// @brief      Get this workgroup flag index.
     ///
     /// @return     The workgroup flag index.
     ///
-    __device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
-                                            index_t output_tile_idx,
-                                            index_t output_tile_idx_offset) const
-    {
-        return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
-    }
+    __device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }

     ///
     /// @brief      Flag each workgroup that has finished its work.
     ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index
-    /// @param[in]  output_tile_idx_offset  The output tile index offset
-    ///
-    __device__ void
-    FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
-    {
-        const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
-        finished_block_flags_.inc(fidx);
-    }
+    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }

     ///
     /// @brief      Wait until each workgroup has finished its work.
     ///
     /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
     /// @param[in]  k_tile_idx              The currently processed tile k index.
-    /// @param[in]  output_tile_idx         The output (MN) tile index
-    /// @param[in]  output_tile_idx_offset  The output tile index offset
     ///
     /// @return     The number of neighbours.
     ///
-    __device__ index_t WaitForNeighbours(index_t k_tiles,
-                                         index_t k_tile_idx,
-                                         index_t output_tile_idx,
-                                         index_t output_tile_idx_offset)
+    __device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
     {
         // We have to wait for all workgroups to finish their partial results.
         // First count how many "neighbour" workgroups we have to check.
@@ -139,57 +105,48 @@ class StridedReductionTileLoop
         if(neighbour_count > 0)
         {
-            // Also count this workgroup
-            neighbour_count++;
-            finished_block_flags_.wait_eq(
-                GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
-                neighbour_count);
+            index_t flag_sum = 0;
+            do
+            {
+                flag_sum = 0;
+                for(index_t i = 1; i <= neighbour_count; ++i)
+                {
+                    flag_sum += finished_block_flags_.ld(GetWorkgroupFlagIdx() + i);
+                }
+            } while(flag_sum != neighbour_count);
         }
         return neighbour_count;
     }

     ///
-    /// @brief      Wait until each workgroup has finished its work.
-    ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index
-    /// @param[in]  output_tile_idx_offset  The output tile index offset
+    /// @brief      Wait until reduction workgroup has finished its work.
     ///
-    __device__ void
-    WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
+    __device__ void WaitForReduction()
     {
-        // Wait untill the counter has been reset.
-        finished_block_flags_.wait_eq(
-            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
+        // Wait untill my counter has been reset.
+        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
     }

     ///
     /// @brief      Reset flag counter to zero.
     ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index.
-    /// @param[in]  output_tile_idx_offset  The output tile index offset.
+    /// @param[in]  neighbour_count  The number of peer workgroups.
     ///
-    __device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
-    {
-        finished_block_flags_.reset(
-            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
-    }
+    __device__ void Reset(index_t neighbour_count)
+    {
+        for(index_t i = 0; i <= neighbour_count; ++i)
+        {
+            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
+        }
+    }

     ///
     /// @brief      Gets the flag value.
     ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index.
-    /// @param[in]  output_tile_idx_offset  The output tile index offset.
-    ///
-    __device__ uint32_t GetFlagValue(index_t k_tiles,
-                                     index_t output_tile_idx,
-                                     index_t output_tile_idx_offset) const
+    __device__ uint32_t GetFlagValue() const
     {
-        return finished_block_flags_.ld(
-            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
+        return finished_block_flags_.ld(GetWorkgroupFlagIdx());
     }

     const index_t tile_count_;
...
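The core of the change is `GetWorkgroupFlagIdx()`: the slot used to be derived from the output tile being reduced, folded into the available flag count, while it is now simply the block id. A small standalone comparison of the two computations (re-implemented here purely for illustration; the constants are hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

// Old scheme: the slot depends on which MN-output tile is being reduced, taken
// modulo flag_count, so distinct combined tile indices can fold onto one slot.
std::uint32_t old_flag_idx(int output_tile_idx, int output_tile_idx_offset, int flag_count)
{
    return static_cast<std::uint32_t>((output_tile_idx + output_tile_idx_offset) % flag_count);
}

// New scheme: every workgroup owns exactly one slot, its block id.
std::uint32_t new_flag_idx(int block_idx) { return static_cast<std::uint32_t>(block_idx); }

int main()
{
    const int flag_count = 152; // hypothetical old-scheme flag count
    std::printf("old: combined tile index 10 -> slot %u, combined tile index 162 -> slot %u\n",
                old_flag_idx(10, 0, flag_count),
                old_flag_idx(10, 152, flag_count)); // both fold onto slot 10
    std::printf("new: workgroup 7 -> slot %u\n", new_flag_idx(7));
    return 0;
}
```

With one slot per workgroup, the scheduler no longer needs `k_batch` or the accumulated output-tile offsets to locate its flag, which is why those parameters disappear from the API above.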
@@ -5,7 +5,7 @@
 namespace ck {

 struct workgroup_barrier
 {
-    __device__ workgroup_barrier(volatile uint32_t* ptr) : base_ptr(ptr) {}
+    __device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}

     __device__ uint32_t ld(uint32_t offset) const
     {
@@ -53,7 +53,7 @@ struct workgroup_barrier
     {
         if(threadIdx.x == 0)
         {
-            while(atomicCAS(const_cast<uint32_t*>(base_ptr + offset), compare, value) != compare) {}
+            while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
         }
         __syncthreads();
     }
@@ -68,7 +68,7 @@ struct workgroup_barrier
     {
         if(threadIdx.x == 0)
         {
-            atomicAdd(const_cast<uint32_t*>(base_ptr + offset), 1);
+            atomicAdd(base_ptr + offset, 1);
         }
         __syncthreads();
     }
@@ -82,6 +82,6 @@ struct workgroup_barrier
         __syncthreads();
     }

-    volatile uint32_t* base_ptr;
+    uint32_t* base_ptr;
 };

 } // namespace ck
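With `volatile` gone, cross-workgroup visibility of the flags rests entirely on the atomic operations themselves, which is also why the `const_cast`s are no longer needed. As a hedged illustration of the pattern (not an actual CK member function), a busy-wait on a flag value can be expressed purely with `atomicCAS`:

```cpp
// Illustration only: waiting for a flag to reach 'compare' using atomicCAS
// alone. Writing 'compare' back when it matches leaves the value unchanged
// while making every access an atomic read-modify-write, so neither volatile
// nor const_cast is required on the flag pointer.
__device__ void wait_eq_sketch(unsigned int* flag, unsigned int compare)
{
    if(threadIdx.x == 0)
    {
        while(atomicCAS(flag, compare, compare) != compare) {}
    }
    __syncthreads();
}
```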
@@ -166,22 +166,18 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
             // partial_result;
         }

-        const index_t output_tile_idx =
-            __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
-        const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
-        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
+        work_scheduler.FlagFinished();

         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
             // Wait untill all other blocks for this [M,N] tile store their results.
-            index_t neighbour_count = work_scheduler.WaitForNeighbours(
-                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
+            index_t neighbour_count =
+                work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());

             // Accumulate partial results. We can have different # of workgroups to reduce, thus we
             // read actual flag value.
-            for(index_t i = 1; i < neighbour_count; ++i)
+            for(index_t i = 1; i <= neighbour_count; ++i)
             {
                 // partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                 //                               i * MPerBlock * NPerBlock +
@@ -199,7 +195,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
             }

             // Signal waiting blocks that they can start use their workspace.
-            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.Reset(neighbour_count);

             // write result
             const index_t C_m_tile_offset = block_m_id * MPerBlock;
@@ -221,7 +217,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         }
         else if(work_scheduler.HasTile())
         {
-            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.WaitForReduction();
         }
     } while(work_scheduler.HasTile());
@@ -328,11 +324,7 @@ struct GroupedGemmStridedTileLoopReduce
     gemm_descs_device_buf.ToDevice(gemm_descs.data());

     DeviceMem gemm_workspace, gemm_flags;
-    const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
-    // This is the number of MN-output tiles which we cover with workgroups.
-    // We launch k_batch / tiles_per_block workgroups for each output tile.
-    const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
+    const index_t flag_count = grid_size;

     gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
     gemm_flags.Realloc(flag_count * sizeof(uint32_t));
...
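Host-side, the flag buffer now simply scales with the launched grid rather than with the number of covered MN-output tiles. A worked comparison using hypothetical launch parameters (the numbers below are chosen only to make the arithmetic concrete):

```cpp
#include <cstdio>

int main()
{
    // Hypothetical launch parameters for illustration only.
    const int grid_size       = 304; // e.g. one workgroup per CU
    const int tiles_per_block = 4;
    const int k_batch         = 8;

    // Old scheme: one flag per MN-output tile covered by the launched grid.
    const int old_flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch; // 152
    // New scheme: one flag per workgroup.
    const int new_flag_count = grid_size; // 304

    std::printf("flags: old=%d, new=%d (one uint32_t each)\n", old_flag_count, new_flag_count);
    return 0;
}
```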