Commit f8cbbd1b authored by Adam Osewski's avatar Adam Osewski
Browse files

Change return type from index_t to uint32_t for GetWorkgroupFlagIdx and GetFlagValue.

Update doc.
parent f41a265a
...@@ -65,18 +65,19 @@ class StridedReductionTileLoop ...@@ -65,18 +65,19 @@ class StridedReductionTileLoop
/// @brief Calculate this workgroup flag index. /// @brief Calculate this workgroup flag index.
/// ///
/// @note Note this scheduler intentionally does not have flag index as its member, since /// @note Note this scheduler intentionally does not have flag index as its member, since
/// the number of `k_tiles` may change when iterating (i.e. in grouped gemm, /// current workgroup may process tiles across different MN-output tiles or
/// different groups may have different `k_tiles` in K dimension). /// across different GEMMs (grouped gemm).
/// ///
/// @param[in] k_tiles The number of data tiles in the reduced dimension. /// @param[in] k_tiles The number of data tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index (of current GEMM). /// @param[in] output_tile_idx The output (MN) linear tile index (of current GEMM).
/// @param[in] output_tile_idx_offset The output tile index offset. /// @param[in] output_tile_idx_offset The accumulated offset of output tiles from previous
/// GEMMs.
/// ///
/// @return The workgroup flag index. /// @return The workgroup flag index.
/// ///
__device__ index_t GetWorkgroupFlagIdx(index_t k_tiles, __device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
index_t output_tile_idx, index_t output_tile_idx,
index_t output_tile_idx_offset) const index_t output_tile_idx_offset) const
{ {
return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles); return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
} }
...@@ -91,8 +92,9 @@ class StridedReductionTileLoop ...@@ -91,8 +92,9 @@ class StridedReductionTileLoop
__device__ void __device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset) FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{ {
finished_block_flags_.inc( const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
finished_block_flags_.inc(fidx);
} }
/// ///
...@@ -149,12 +151,12 @@ class StridedReductionTileLoop ...@@ -149,12 +151,12 @@ class StridedReductionTileLoop
/// @param[in] output_tile_idx The output (MN) tile index. /// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset. /// @param[in] output_tile_idx_offset The output tile index offset.
/// ///
__device__ index_t GetFlagValue(index_t k_tiles, __device__ uint32_t GetFlagValue(index_t k_tiles,
index_t output_tile_idx, index_t output_tile_idx,
index_t output_tile_idx_offset) const index_t output_tile_idx_offset) const
{ {
return static_cast<index_t>(finished_block_flags_.ld( return finished_block_flags_.ld(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset))); GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
} }
const index_t tile_count_; const index_t tile_count_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment