Commit 7316bd15 authored by Adam Osewski
Browse files

Add more functionality to work scheduler.

parent 70974baf
...@@ -42,6 +42,11 @@ class StridedReductionTileLoop ...@@ -42,6 +42,11 @@ class StridedReductionTileLoop
{ {
} }
/// @brief Tell whether there is still a tile to process.
///
/// @return `true` when `tile_id_` is below `tile_count_` and `block_tile_idx_` is below
///         `tiles_per_block_`; `false` otherwise.
///
__device__ bool HasTile() const
{
    // Both limits must hold: the overall tile range and the per-block tile budget.
    const bool within_tile_range   = tile_id_ < tile_count_;
    const bool within_block_budget = block_tile_idx_ < tiles_per_block_;
    return within_tile_range && within_block_budget;
}
__device__ bool GetNextTile() __device__ bool GetNextTile()
{ {
tile_id_++; tile_id_++;
...@@ -53,45 +58,99 @@ class StridedReductionTileLoop ...@@ -53,45 +58,99 @@ class StridedReductionTileLoop
/// @brief Calculate this workgroup flag index. /// @brief Calculate this workgroup flag index.
/// ///
/// @note Note this scheduler intentionally does not have flag index as its member, since /// @note Note this scheduler intentionally does not have flag index as its member, since
/// the number of `dim_tiles` may change when iterating (i.e. in grouped gemm, /// the number of `k_tiles` may change when iterating (i.e. in grouped gemm,
/// different groups may have different `dim_tiles` in K dimension). /// different groups may have different `k_tiles` in K dimension).
/// ///
/// @param[in] dim_tiles The number of data tiles in the reduced dimension. /// @param[in] k_tiles The number of data tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index. /// @param[in] output_tile_idx The output (MN) tile index (of current GEMM).
/// @param[in] output_tile_idx_offset The output tile index offset.
/// ///
/// @return The workgroup flag index. /// @return The workgroup flag index.
/// ///
__device__ index_t GetWorkgroupFlagIdx(index_t dim_tiles, index_t output_tile_idx) const __device__ index_t GetWorkgroupFlagIdx(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{ {
// This is the number of MN-output tiles which we cover with workgroups. // This is the number of MN-output tiles which we cover with workgroups.
// We launch dim_tiles (k_batch) / tiles_per_block workgroups for each output tile. // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
const index_t flag_count = (get_grid_size() * tiles_per_block_ + dim_tiles - 1) / dim_tiles; const index_t flag_count = (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
return output_tile_idx % flag_count; return (output_tile_idx + output_tile_idx_offset) % flag_count;
} }
/// ///
/// @brief Flag each workgroup that has finished its work. /// @brief Flag each workgroup that has finished its work.
/// ///
/// @param[in] dim_tiles The number of tiles in the reduced dimension. /// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index /// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
///
__device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Resolve the flag slot that the (offset) output tile maps to.
    const index_t flag_idx =
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
    // Mark this workgroup as finished by incrementing that slot.
    finished_block_flags_.inc(flag_idx);
}
///
/// @brief Wait until each workgroup has finished its work.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
/// ///
__device__ void FlagFinished(index_t dim_tiles, index_t output_tile_idx) __device__ void
WaitForNeighbours(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{ {
finished_block_flags_.inc(GetWorkgroupFlagIdx(dim_tiles, output_tile_idx)); // Wait until all workgroups finish
const index_t workgroups_per_dim = (k_tiles + tiles_per_block_ - 1) / tiles_per_block_;
// We use < because for some cases we may have +1 more workgroups per dim.
// E.g. when k_tiles = 5, tiles_per_block = 3.
finished_block_flags_.wait_lt(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
workgroups_per_dim);
} }
/// ///
/// @brief Wait until each workgroup has finished its work. /// @brief Wait until the flag counter has been reset to zero.
/// ///
/// @param[in] dim_tiles The number of tiles in the reduced dimension. /// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index /// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
///
__device__ void
WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Resolve the flag slot that the (offset) output tile maps to.
    const index_t flag_idx =
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
    // Wait until the counter in that slot has been reset, i.e. it equals zero.
    finished_block_flags_.wait_eq(flag_idx, 0);
}
///
/// @brief Reset flag counter to zero.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
///
__device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
    // Resolve the flag slot that the (offset) output tile maps to.
    const index_t flag_idx =
        GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
    // Reset that slot's counter.
    finished_block_flags_.reset(flag_idx);
}
///
/// @brief Gets the flag value.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
/// ///
__device__ void WaitForNeighbours(index_t dim_tiles, index_t output_tile_idx) __device__ index_t GetFlagValue(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{ {
// Wait until all workgroups finish and reset counter. return static_cast<index_t>(finished_block_flags_.ld(
const index_t workgroups_per_dim = (dim_tiles + tiles_per_block_ - 1) / tiles_per_block_; GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)));
finished_block_flags_.wait_set(
GetWorkgroupFlagIdx(dim_tiles, output_tile_idx), workgroups_per_dim, 0);
} }
const index_t tile_count_; const index_t tile_count_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment