"...composable_kernel_rocm.git" did not exist on "1db756036584e305196403c0920ad711584cf017"
Commit f8cbbd1b authored by Adam Osewski's avatar Adam Osewski
Browse files

Change return type from inxed_t to uint32_t for GetFlagValue.

Update doc.
parent f41a265a
......@@ -65,18 +65,19 @@ class StridedReductionTileLoop
/// @brief Calculate this workgroup flag index.
///
/// @note Note this scheduler intentionaly does not have flag index as its member, since
/// the number of `k_tiles` may change when iterating (ie. in grouped gemm,
/// different groups may have different `k_tiles` in K dimension).
/// current workgroup may process tiles across different MN-output tiles or
/// acorss different GEMMs (grouped gemm).
///
/// @param[in] k_tiles The number of data tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) tile index (of current GEMM).
/// @param[in] output_tile_idx_offset The output tile index offset.
/// @param[in] k_tiles The number of data tiles in the reduced dimension.
/// @param[in] output_tile_idx The output (MN) linear tile index (of current GEMM).
/// @param[in] output_tile_idx_offset The accumulated offset of output tiles from previous
/// GEMMs.
///
/// @return The workgroup flag index.
///
__device__ index_t GetWorkgroupFlagIdx(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
__device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{
return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
}
......@@ -91,8 +92,9 @@ class StridedReductionTileLoop
__device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
finished_block_flags_.inc(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
finished_block_flags_.inc(fidx);
}
///
......@@ -149,12 +151,12 @@ class StridedReductionTileLoop
/// @param[in] output_tile_idx The output (MN) tile index.
/// @param[in] output_tile_idx_offset The output tile index offset.
///
__device__ index_t GetFlagValue(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
__device__ uint32_t GetFlagValue(index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{
return static_cast<index_t>(finished_block_flags_.ld(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)));
return finished_block_flags_.ld(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
}
const index_t tile_count_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment