Commit 73adb83d authored by Adam Osewski
Browse files

Do not synchronize when it's not necessary.

parent c5c95578
......@@ -133,6 +133,7 @@ __global__ void
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
work_scheduler.SetIsSyncNeeded(k_tiles, k_batch);
// just accumulate results in registers!
GridwiseGemm::template RunGEMM(p_a_grid,
......@@ -874,6 +875,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","
<< std::string(ELayout::name)[0] << ","
<< NumGemmKPrefetchStage << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
......
......@@ -33,11 +33,13 @@ class StridedReductionTileLoop
{
public:
__device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
: tile_count_{tile_count},
tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
tile_id_{get_block_1d_id() * tiles_per_block_},
block_tile_idx_{0},
finished_block_flags_{p_flags}
: tile_count_{__builtin_amdgcn_readfirstlane(tile_count)},
tiles_per_block_{__builtin_amdgcn_readfirstlane((tile_count_ + get_grid_size() - 1) /
get_grid_size())},
tile_id_{__builtin_amdgcn_readfirstlane(get_block_1d_id() * tiles_per_block_)},
block_tile_idx_{__builtin_amdgcn_readfirstlane(0)},
finished_block_flags_{p_flags},
is_sync_needed_{1}
{
}
......@@ -80,11 +82,18 @@ class StridedReductionTileLoop
///
/// @brief Flag each workgroup that has finished its work.
///
__device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
__device__ void FlagFinished()
{
    // When the sync flag is cleared there is nothing to signal, so leave the
    // workgroup counter untouched.
    if(!is_sync_needed_)
    {
        return;
    }

    // Increment the counter located at this workgroup's flag index.
    finished_block_flags_.inc(GetWorkgroupFlagIdx());
}
///
/// @brief Wait until each workgroup has finished its work.
///
/// @note This function assumes it's called by the WGP which processes the first
/// k-tile.
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] k_tile_idx The currently processed tile k index.
///
......@@ -95,6 +104,10 @@ class StridedReductionTileLoop
// We have to wait for all workgroups to finish their partial results.
// First count how many "neighbour" workgroups we have to check.
index_t neighbour_count = 0;
if(!is_sync_needed_)
return neighbour_count;
if(tiles_per_block_ < k_tiles)
{
// Since we can have deviation (+/-1) in neighbours number
......@@ -135,8 +148,9 @@ class StridedReductionTileLoop
///
__device__ void WaitForReduction()
{
// Wait until my counter has been reset.
finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
if(is_sync_needed_)
// Wait until my counter has been reset.
finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
}
///
......@@ -146,9 +160,12 @@ class StridedReductionTileLoop
///
__device__ void Reset(index_t neighbour_count)
{
for(index_t i = 0; i <= neighbour_count; ++i)
if(is_sync_needed_)
{
finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
for(index_t i = 0; i <= neighbour_count; ++i)
{
finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
}
}
}
......@@ -160,11 +177,17 @@ class StridedReductionTileLoop
return finished_block_flags_.ld(GetWorkgroupFlagIdx());
}
///
/// @brief Update the flag that gates the flag-based synchronization helpers.
///
/// @note The flag is cleared (0) when both arguments are equal and set (1)
///       otherwise. NOTE(review): presumably equality means this workgroup
///       handles the whole reduced dimension on its own - confirm against
///       the caller.
///
/// @param[in] next_k_tiles The first value of the comparison.
/// @param[in] k_tiles      The second value of the comparison.
///
__device__ void SetIsSyncNeeded(index_t next_k_tiles, index_t k_tiles)
{
    const index_t sync_flag = (next_k_tiles == k_tiles) ? 0 : 1;
    // Pass the value through readfirstlane exactly as before.
    is_sync_needed_ = __builtin_amdgcn_readfirstlane(sync_flag);
}
const index_t tile_count_;
const index_t tiles_per_block_;
index_t tile_id_;
index_t block_tile_idx_;
workgroup_barrier finished_block_flags_;
index_t is_sync_needed_;
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment