Commit 986f1d2b authored by Adam Osewski's avatar Adam Osewski
Browse files

Add more functionality to work scheduler.

parent 33ac23c6
......@@ -42,6 +42,11 @@ class StridedReductionTileLoop
{
}
/// @brief Check whether this workgroup still has a tile to process.
///
/// @return `true` only when the global tile id is below the total tile count and the
///         per-block tile index is below the per-block tile limit.
__device__ bool HasTile() const
{
    // Global bound first: once it fails, the per-block bound is irrelevant.
    if(!(tile_id_ < tile_count_))
    {
        return false;
    }
    return block_tile_idx_ < tiles_per_block_;
}
__device__ bool GetNextTile()
{
tile_id_++;
......@@ -53,45 +58,99 @@ class StridedReductionTileLoop
/// @brief Calculate this workgroup flag index.
///
/// @note Note this scheduler intentionally does not have flag index as its member, since
/// the number of `k_tiles` may change when iterating (ie. in grouped gemm,
/// different groups may have different `k_tiles` in K dimension).
///
/// @param[in] k_tiles The number of data tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index (of current GEMM).
/// @param[in] output_tile_idx_offset The output tile index offset.
///
/// @return The workgroup flag index.
///
__device__ index_t GetWorkgroupFlagIdx(index_t k_tiles,
                                       index_t output_tile_idx,
                                       index_t output_tile_idx_offset) const
{
    // This is the number of MN-output tiles which we cover with workgroups.
    // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile,
    // so the grid's total tile capacity divided by k_tiles (rounded up) gives the
    // number of distinct flags in use.
    const index_t flag_count = (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
    // The offset shifts the tile index before the wrap-around into [0, flag_count).
    // NOTE(review): presumably the offset is the number of output tiles of preceding
    // GEMMs in a grouped launch - confirm against the caller.
    return (output_tile_idx + output_tile_idx_offset) % flag_count;
}
///
/// @brief Flag each workgroup that has finished its work.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
///
__device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Resolve which flag slot belongs to this output tile, then bump its counter.
    const index_t flag_idx =
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
    finished_block_flags_.inc(flag_idx);
}
///
/// @brief Wait until each workgroup has finished its work.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
///
__device__ void
WaitForNeighbours(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Wait until all workgroups contributing to this output tile have finished.
    // Number of workgroups that share one output tile (ceil division).
    const index_t workgroups_per_dim = (k_tiles + tiles_per_block_ - 1) / tiles_per_block_;
    // We use < because for some cases we may have +1 more workgroups per dim.
    // Ie when k_tiles = 5, tiles_per_block = 3.
    // NOTE(review): presumably wait_lt spins while the flag is below the threshold -
    // confirm against the flag container's implementation.
    finished_block_flags_.wait_lt(
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
        workgroups_per_dim);
}
///
/// @brief Wait until the flag counter of this output tile has been reset to zero.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
///
__device__ void
WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Locate the flag slot for this output tile.
    const index_t flag_idx =
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
    // Block here until the counter has been reset, i.e. reads back as zero.
    finished_block_flags_.wait_eq(flag_idx, 0);
}
///
/// @brief Reset flag counter to zero.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
///
__device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Find the flag slot of this output tile and reset its counter.
    const index_t flag_idx =
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
    finished_block_flags_.reset(flag_idx);
}
///
/// @brief Gets the flag value.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
///
__device__ index_t GetFlagValue(index_t k_tiles,
                                index_t output_tile_idx,
                                index_t output_tile_idx_offset) const
{
    // Pure read: load the current counter of this output tile's flag slot and
    // convert it to index_t. No waiting and no reset happens here.
    return static_cast<index_t>(finished_block_flags_.ld(
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)));
}
const index_t tile_count_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment