Commit 98220c32 authored by Adam Osewski
Browse files

Refactor out FlagCount function.

parent 991e44ee
......@@ -54,6 +54,13 @@ class StridedReductionTileLoop
return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
}
///
/// @brief Compute how many flags are needed for the current launch.
///
/// @param[in] k_tiles The number of tiles along the K dimension (k_batch).
///
/// @return The grid's total tile capacity divided by @p k_tiles, rounded up.
///
__device__ index_t GetFlagCount(index_t k_tiles) const
{
    // Total number of tiles the whole grid can process: every workgroup
    // handles tiles_per_block_ tiles.
    const index_t grid_tile_capacity = get_grid_size() * tiles_per_block_;

    // Each MN-output tile is shared by k_tiles / tiles_per_block_ workgroups,
    // so the number of output tiles covered by the grid is the capacity
    // divided by k_tiles, rounded up.
    return (grid_tile_capacity + k_tiles - 1) / k_tiles;
}
///
/// @brief Calculate this workgroup flag index.
///
......@@ -71,10 +78,7 @@ class StridedReductionTileLoop
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{
// This is the number of MN-output tiles which we cover with workgroups.
// We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
const index_t flag_count = (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
return (output_tile_idx + output_tile_idx_offset) % flag_count;
return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
}
///
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment