Commit e954c206 authored by Adam Osewski's avatar Adam Osewski
Browse files

Clean up and change how neighbours are counted.

parent 7e71ea99
...@@ -79,8 +79,8 @@ class StridedReductionTileLoop ...@@ -79,8 +79,8 @@ class StridedReductionTileLoop
index_t output_tile_idx, index_t output_tile_idx,
index_t output_tile_idx_offset) const index_t output_tile_idx_offset) const
{ {
// return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles); return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
return output_tile_idx + output_tile_idx_offset; // return output_tile_idx + output_tile_idx_offset;
} }
/// ///
...@@ -93,7 +93,7 @@ class StridedReductionTileLoop ...@@ -93,7 +93,7 @@ class StridedReductionTileLoop
__device__ void __device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset) FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{ {
/* [[maybe_unused]] */const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset); const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
finished_block_flags_.inc(fidx); finished_block_flags_.inc(fidx);
} }
...@@ -101,21 +101,51 @@ class StridedReductionTileLoop ...@@ -101,21 +101,51 @@ class StridedReductionTileLoop
/// @brief Wait until each workgroup has finished its work. /// @brief Wait until each workgroup has finished its work.
/// ///
/// @param[in] k_tiles The number of tiles in the reduced dimension. /// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] k_tile_idx The currently processed tile k index.
/// @param[in] output_tile_idx The output (MN) tile index /// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset /// @param[in] output_tile_idx_offset The output tile index offset
/// ///
__device__ void /// @return The number of neighbours.
WaitForNeighbours(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset) ///
__device__ index_t WaitForNeighbours(index_t k_tiles,
index_t k_tile_idx,
index_t output_tile_idx,
index_t output_tile_idx_offset)
{ {
// Wait untill all workgroups finish // We have to wait for all workgroups to finish their partial results.
const index_t workgroups_per_dim = (k_tiles + tiles_per_block_ - 1) / tiles_per_block_; // First count how many "neighbour" workgroups we have to check.
// We use < because for some cases we may have +1 more workgroups per dim. index_t neighbour_count = 0;
// Ie when k_tiles = 5, tiles_per_block = 3. if(tiles_per_block_ < k_tiles)
finished_block_flags_.wait_lt( {
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), // Since we can have deviation (+1) in neighbours number
workgroups_per_dim); // we calculate how many workgroups are needed to process the k-tiles left.
neighbour_count = (k_tiles - k_tile_idx - 1 + tiles_per_block_ - 1) / tiles_per_block_;
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset); }
// If we have more tiles to process than the reduction dimension size,
// then the number of neighbours depends on first K-tile workgroup block tile idx.
else
{
if(block_tile_idx_ == tiles_per_block_)
{
// If we just finished work per workgroup then check at which k-idx we are.
neighbour_count = (k_tile_idx < k_tiles - 1) ? 1 : 0;
}
else
{
// If we have still tiles to process then it means that we already processed
// whole K-dim.
neighbour_count = 0;
}
}
if(neighbour_count > 0)
{
finished_block_flags_.wait_lt(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
neighbour_count);
}
return neighbour_count;
} }
/// ///
...@@ -131,8 +161,6 @@ class StridedReductionTileLoop ...@@ -131,8 +161,6 @@ class StridedReductionTileLoop
// Wait untill the counter has been reset. // Wait untill the counter has been reset.
finished_block_flags_.wait_eq( finished_block_flags_.wait_eq(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0); GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
} }
/// ///
...@@ -146,8 +174,6 @@ class StridedReductionTileLoop ...@@ -146,8 +174,6 @@ class StridedReductionTileLoop
{ {
finished_block_flags_.reset( finished_block_flags_.reset(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)); GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
} }
/// ///
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment