Commit 5b2eb1ab authored by Adam Osewski

Single flag per workgroup synchronization scheme.

parent f8ca9048
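
For orientation, below is a minimal, hypothetical sketch of the synchronization scheme this commit moves to: every launched workgroup owns exactly one flag, indexed by `blockIdx.x`. A workgroup that finishes its K slice increments its own flag; the first-K-split workgroup of an MN output tile polls the flags of its `neighbour_count` successor blocks until they are all set, accumulates the partial results, and finally resets the whole flag range so the peers can reuse their workspace. The helper names (`flag_finished`, `wait_for_neighbours`, `reset_flags`) are illustrative only, not the actual CK classes (`StridedReductionTileLoop`, `workgroup_barrier`).

```cpp
// Hypothetical sketch of the single-flag-per-workgroup protocol (not the actual CK API).
#include <hip/hip_runtime.h>
#include <cstdint>

// Each workgroup owns flags[blockIdx.x]; incrementing it signals "my partials are stored".
__device__ void flag_finished(uint32_t* flags)
{
    if(threadIdx.x == 0)
        atomicAdd(&flags[blockIdx.x], 1u);
    __syncthreads();
}

// The first-K-split workgroup polls the flags of the neighbour_count successor blocks
// that computed partials for the same MN tile (consecutive block ids under the strided loop).
__device__ void wait_for_neighbours(uint32_t* flags, uint32_t neighbour_count)
{
    if(threadIdx.x == 0)
    {
        uint32_t flag_sum;
        do
        {
            flag_sum = 0;
            for(uint32_t i = 1; i <= neighbour_count; ++i)
                flag_sum += atomicAdd(&flags[blockIdx.x + i], 0u); // atomic read
        } while(flag_sum != neighbour_count);
    }
    __syncthreads();
}

// After accumulation the same workgroup clears its own flag and its neighbours' flags,
// releasing the peers that spin on their own slot before reusing their workspace.
__device__ void reset_flags(uint32_t* flags, uint32_t neighbour_count)
{
    if(threadIdx.x == 0)
        for(uint32_t i = 0; i <= neighbour_count; ++i)
            atomicExch(&flags[blockIdx.x + i], 0u);
    __syncthreads();
}
```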
......@@ -156,32 +156,28 @@ __global__ void
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
const index_t output_tile_idx =
__builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
GridwiseGemm::StorePartials(p_workspace, results_buffer);
}
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
work_scheduler.FlagFinished();
// The workgroup which processed the first K tile accumulates results and stores them to GMEM
if(b2c_tile_map.IsFirstKSplitBlock())
{
// Wait until all other blocks for this [M,N] tile have stored their results.
index_t neighbour_count = work_scheduler.WaitForNeighbours(
k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
index_t neighbour_count =
work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
// Accumulate only when there are at least two workgroups processing split-K data tiles
// across the same MN-output tile.
if(neighbour_count > 1)
GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count);
if(neighbour_count > 0)
GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1);
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
work_scheduler.Reset(neighbour_count);
const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
const auto stride_e = gemm_desc_ptr[group_id].StrideE;
......@@ -210,7 +206,7 @@ __global__ void
}
else if(work_scheduler.HasTile())
{
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
work_scheduler.WaitForReduction();
}
} while(work_scheduler.HasTile());
#else
......@@ -752,7 +748,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size);
std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
std::size_t flag_count = grid_size;
if(stream_config.log_level_ > 0)
{
......@@ -993,7 +989,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
{
grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
}
int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
int flag_count = grid_size;
// This would be the maximum needed workspace size. Since the actual grid size, which determines
// the number of workspace bytes needed, may be smaller due to the number of available CUs in
......
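
As a rough host-side illustration (hypothetical helper and parameter names, assuming `float` accumulation as in the example kernel below): the flag buffer now holds one `uint32_t` per launched workgroup and is placed right after the per-block accumulation workspace. With, say, grid_size = 304, tiles_per_block = 2 and K_BATCH = 4, the old formula allocated (304 * 2 + 3) / 4 = 152 shared flags, whereas the new scheme allocates 304, one per workgroup.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical sizing helper: accumulation workspace for every workgroup plus one flag
// per workgroup appended after it (flag_count == grid_size in the new scheme).
std::size_t WorkspaceSizeBytes(std::size_t grid_size, std::size_t m_per_block, std::size_t n_per_block)
{
    const std::size_t acc_bytes  = grid_size * m_per_block * n_per_block * sizeof(float);
    const std::size_t flag_bytes = grid_size * sizeof(std::uint32_t);
    return acc_bytes + flag_bytes;
}
```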
......@@ -32,8 +32,7 @@ enum struct WorkSchedulingPolicy
class StridedReductionTileLoop
{
public:
__device__ StridedReductionTileLoop(index_t tile_count,
volatile uint32_t* const __restrict__ p_flags)
__device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
: tile_count_{tile_count},
tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
tile_id_{get_block_1d_id() * tiles_per_block_},
......@@ -54,62 +53,29 @@ class StridedReductionTileLoop
return HasTile();
}
__device__ index_t GetFlagCount(index_t k_tiles) const
{
// This is the number of MN-output tiles which we cover with workgroups.
// We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
return (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
}
__device__ index_t GetFlagCount() const { return get_grid_size(); }
///
/// @brief Calculate this workgroup flag index.
///
/// @note This scheduler intentionally does not keep the flag index as a member, since the
///       current workgroup may process tiles across different MN-output tiles or
///       across different GEMMs (grouped gemm).
///
/// @param[in] k_tiles The number of data tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) linear tile index (of current GEMM).
/// @param[in] output_tile_idx_offset The accumulated offset of output tiles from previous
/// GEMMs.
/// @brief Get this workgroup flag index.
///
/// @return The workgroup flag index.
///
__device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{
return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
}
__device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }
///
/// @brief Flag each workgroup that has finished its work.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
///
__device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
finished_block_flags_.inc(fidx);
}
__device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
///
/// @brief Wait until each workgroup has finished its work.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] k_tile_idx The currently processed tile k index.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] k_tile_idx The currently processed tile k index.
///
/// @return The number of neighbours.
///
__device__ index_t WaitForNeighbours(index_t k_tiles,
index_t k_tile_idx,
index_t output_tile_idx,
index_t output_tile_idx_offset)
__device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
{
// We have to wait for all workgroups to finish storing their partial results.
// First count how many "neighbour" workgroups we have to check.
......@@ -139,57 +105,48 @@ class StridedReductionTileLoop
if(neighbour_count > 0)
{
// Also count this workgroup
neighbour_count++;
finished_block_flags_.wait_eq(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
neighbour_count);
index_t flag_sum = 0;
do
{
flag_sum = 0;
for(index_t i = 1; i <= neighbour_count; ++i)
{
flag_sum += finished_block_flags_.ld(GetWorkgroupFlagIdx() + i);
}
} while(flag_sum != neighbour_count);
}
return neighbour_count;
}
///
/// @brief Wait until each workgroup has finished its work.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
/// @brief Wait until the reduction workgroup has finished its work.
///
__device__ void
WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
__device__ void WaitForReduction()
{
// Wait until the counter has been reset.
finished_block_flags_.wait_eq(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
// Wait until this workgroup's counter has been reset.
finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
}
///
/// @brief Reset flag counter to zero.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
/// @param[in] neighbour_count The number of peer workgroups.
///
__device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
__device__ void Reset(index_t neighbour_count)
{
finished_block_flags_.reset(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
for(index_t i = 0; i <= neighbour_count; ++i)
{
finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
}
}
///
/// @brief Gets the flag value.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
///
__device__ uint32_t GetFlagValue(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
__device__ uint32_t GetFlagValue() const
{
return finished_block_flags_.ld(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
return finished_block_flags_.ld(GetWorkgroupFlagIdx());
}
const index_t tile_count_;
......
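
A small stand-alone illustration of why the old flag indexing carried the extra `k_tiles` / `output_tile_idx` / `output_tile_idx_offset` parameters: the modulo mapping made several output tiles alias to the same flag slot, so the slot had to be derived from the tile currently being processed, whereas the new mapping is simply the block id. The numbers in the snippet are made up for illustration.

```cpp
#include <cstdio>

int main()
{
    const int flag_count = 4; // made-up old GetFlagCount(k_tiles) for some launch
    for(int tile = 0; tile < 8; ++tile)
    {
        // Old scheme: flag index = (output_tile_idx + output_tile_idx_offset) % flag_count,
        // so tiles 0 and 4 (offset assumed 0 here) share flag 0, tiles 1 and 5 share flag 1, ...
        std::printf("output tile %d -> old flag %d\n", tile, tile % flag_count);
    }
    // New scheme: every workgroup uses flags[blockIdx.x]; no sharing and no modulo arithmetic.
    return 0;
}
```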
......@@ -5,7 +5,7 @@
namespace ck {
struct workgroup_barrier
{
__device__ workgroup_barrier(volatile uint32_t* ptr) : base_ptr(ptr) {}
__device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
__device__ uint32_t ld(uint32_t offset) const
{
......@@ -53,7 +53,7 @@ struct workgroup_barrier
{
if(threadIdx.x == 0)
{
while(atomicCAS(const_cast<uint32_t*>(base_ptr + offset), compare, value) != compare) {}
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
}
__syncthreads();
}
......@@ -68,7 +68,7 @@ struct workgroup_barrier
{
if(threadIdx.x == 0)
{
atomicAdd(const_cast<uint32_t*>(base_ptr + offset), 1);
atomicAdd(base_ptr + offset, 1);
}
__syncthreads();
}
......@@ -82,6 +82,6 @@ struct workgroup_barrier
__syncthreads();
}
volatile uint32_t* base_ptr;
uint32_t* base_ptr;
};
} // namespace ck
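
A minimal sketch of the barrier after this change, under the assumption that every access to the flag memory goes through an atomic operation (which is why the `volatile` qualifier and the `const_cast`s can be dropped). Simplified names, not the full `ck::workgroup_barrier`:

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

struct simple_barrier
{
    uint32_t* flags; // plain pointer: all accesses below are atomic RMW operations

    // One thread bumps the block's flag, then the whole block synchronizes.
    __device__ void inc(uint32_t offset)
    {
        if(threadIdx.x == 0)
            atomicAdd(flags + offset, 1u);
        __syncthreads();
    }

    // Spin until flags[offset] == value; CAS(value, value) acts as an atomic read here.
    __device__ void wait_eq(uint32_t offset, uint32_t value)
    {
        if(threadIdx.x == 0)
            while(atomicCAS(flags + offset, value, value) != value) {}
        __syncthreads();
    }

    // Reset a flag back to zero so waiting peers can proceed.
    __device__ void reset(uint32_t offset)
    {
        if(threadIdx.x == 0)
            atomicExch(flags + offset, 0u);
        __syncthreads();
    }
};
```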
......@@ -166,22 +166,18 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
// partial_result;
}
const index_t output_tile_idx =
__builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
work_scheduler.FlagFinished();
// The workgroup which processed the first K tile accumulates results and stores them to GMEM
if(b2c_tile_map.IsFirstKSplitBlock())
{
// Wait until all other blocks for this [M,N] tile have stored their results.
index_t neighbour_count = work_scheduler.WaitForNeighbours(
k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
index_t neighbour_count =
work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
// Accumulate partial results. We can have a different number of workgroups to reduce, so we
// read the actual flag value.
for(index_t i = 1; i < neighbour_count; ++i)
for(index_t i = 1; i <= neighbour_count; ++i)
{
// partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
// i * MPerBlock * NPerBlock +
......@@ -199,7 +195,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
}
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
work_scheduler.Reset(neighbour_count);
// write result
const index_t C_m_tile_offset = block_m_id * MPerBlock;
......@@ -221,7 +217,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
}
else if(work_scheduler.HasTile())
{
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
work_scheduler.WaitForReduction();
}
} while(work_scheduler.HasTile());
......@@ -328,11 +324,7 @@ struct GroupedGemmStridedTileLoopReduce
gemm_descs_device_buf.ToDevice(gemm_descs.data());
DeviceMem gemm_workspace, gemm_flags;
const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
// This is the number of MN-output tiles which we cover with workgroups.
// We launch k_batch / tiles_per_block workgroups for each output tile.
const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
const index_t flag_count = grid_size;
gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
gemm_flags.Realloc(flag_count * sizeof(uint32_t));
......