Commit 73adb83d authored by Adam Osewski

Do not synchronize when it's not necessary.

parent c5c95578
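In short (reading from the diff below): SetIsSyncNeeded compares the number of K tiles a workgroup has just claimed for its [M,N] tile against the full k_batch. When the two are equal, the workgroup accumulates the whole K range by itself, so no other workgroup holds a partial result and the flag increment / wait / reset traffic around the reduction can be skipped. A minimal sketch of that rule, with illustrative names that are not part of the CK sources:

// Illustrative only: the decision SetIsSyncNeeded encodes, written as a free function.
// next_k_tiles - K tiles this workgroup will process for the current [M,N] tile
// k_batch      - total number of K splits for that [M,N] tile
inline bool NeedsCrossWorkgroupSync(int next_k_tiles, int k_batch)
{
    // Covering the whole K range locally means every partial sum stays in this
    // workgroup's registers, so there is nobody to synchronize with.
    return next_k_tiles != k_batch;
}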
@@ -133,6 +133,7 @@ __global__ void
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
+        work_scheduler.SetIsSyncNeeded(k_tiles, k_batch);
         // just accumulate results in registers!
         GridwiseGemm::template RunGEMM(p_a_grid,
@@ -874,6 +875,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             << std::string(ALayout::name)[0] << ","
             << std::string(BLayout::name)[0] << ","
             << std::string(ELayout::name)[0] << ","
+            << NumGemmKPrefetchStage << ", "
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/utility/common_header.hpp"
 #include "ck/utility/loop_scheduler.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

 namespace ck {
...
@@ -33,11 +33,13 @@ class StridedReductionTileLoop
 {
     public:
     __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
-        : tile_count_{tile_count},
-          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
-          tile_id_{get_block_1d_id() * tiles_per_block_},
-          block_tile_idx_{0},
-          finished_block_flags_{p_flags}
+        : tile_count_{__builtin_amdgcn_readfirstlane(tile_count)},
+          tiles_per_block_{__builtin_amdgcn_readfirstlane((tile_count_ + get_grid_size() - 1) /
+                                                          get_grid_size())},
+          tile_id_{__builtin_amdgcn_readfirstlane(get_block_1d_id() * tiles_per_block_)},
+          block_tile_idx_{__builtin_amdgcn_readfirstlane(0)},
+          finished_block_flags_{p_flags},
+          is_sync_needed_{1}
     {
     }
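An aside on the __builtin_amdgcn_readfirstlane wrappers added above: the intrinsic broadcasts lane 0's value to the whole wavefront, which lets the compiler treat the result as wave-uniform and keep it in scalar registers instead of replicating it per lane. A small hypothetical HIP snippet (not from this repository) showing the same idiom:

// Hypothetical device function: the loop bound is identical for every lane, so the
// readfirstlane copy hints that it can live in an SGPR rather than per-lane VGPRs.
__device__ void scale(float* data, int n_raw, float alpha)
{
    const int n = __builtin_amdgcn_readfirstlane(n_raw); // wave-uniform copy
    for(int i = threadIdx.x; i < n; i += blockDim.x)
        data[i] *= alpha;
}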
@@ -80,11 +82,18 @@ class StridedReductionTileLoop
     ///
     /// @brief Flag each workgroup that has finished its work.
     ///
-    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
+    __device__ void FlagFinished()
+    {
+        if(is_sync_needed_)
+            finished_block_flags_.inc(GetWorkgroupFlagIdx());
+    }

     ///
     /// @brief Wait until each workgroup has finished its work.
     ///
+    /// @note This function assumes it's called by the WGP which processes the first
+    /// k-tile.
+    ///
     /// @param[in] k_tiles The number of tiles in the reduced dimension.
     /// @param[in] k_tile_idx The currently processed tile k index.
     ///
@@ -95,6 +104,10 @@ class StridedReductionTileLoop
         // We have to wait for all workgroups to finish their partial results.
         // First count how many "neighbour" workgroups we have to check.
         index_t neighbour_count = 0;
+
+        if(!is_sync_needed_)
+            return neighbour_count;
+
         if(tiles_per_block_ < k_tiles)
         {
             // Since we can have deviation (+/-1) in neighbours number
@@ -135,8 +148,9 @@ class StridedReductionTileLoop
     ///
     __device__ void WaitForReduction()
     {
-        // Wait untill my counter has been reset.
-        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
+        if(is_sync_needed_)
+            // Wait untill my counter has been reset.
+            finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
     }

     ///
@@ -146,9 +160,12 @@ class StridedReductionTileLoop
     ///
     __device__ void Reset(index_t neighbour_count)
     {
-        for(index_t i = 0; i <= neighbour_count; ++i)
+        if(is_sync_needed_)
         {
-            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
+            for(index_t i = 0; i <= neighbour_count; ++i)
+            {
+                finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
+            }
         }
     }
@@ -160,11 +177,17 @@ class StridedReductionTileLoop
         return finished_block_flags_.ld(GetWorkgroupFlagIdx());
     }

+    __device__ void SetIsSyncNeeded(index_t next_k_tiles, index_t k_tiles)
+    {
+        is_sync_needed_ = __builtin_amdgcn_readfirstlane(next_k_tiles == k_tiles ? 0 : 1);
+    }
+
     const index_t tile_count_;
     const index_t tiles_per_block_;
     index_t tile_id_;
     index_t block_tile_idx_;
     workgroup_barrier finished_block_flags_;
+    index_t is_sync_needed_;
 };

 } // namespace ck
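To see the effect of the new rule end to end, here is a standalone host-side simulation. It assumes the contiguous tile distribution visible in the constructor above (tiles_per_block = ceil(tile_count / grid_size), tile_id = block_id * tiles_per_block) and assumes GetNextKTiles returns the consecutive K tiles a block takes within one [M,N] tile, which the diff does not show; treat it as a sketch, not CK behaviour. It prints which workgroups could skip synchronization under the next_k_tiles == k_batch rule:

// Standalone, assumption-laden sketch of the scheduler's tile distribution and of the
// "skip sync when one block owns the whole K range" rule introduced by this commit.
#include <algorithm>
#include <cstdio>

int main()
{
    const int num_mn_tiles = 6; // number of [M,N] output tiles (illustrative)
    const int k_batch      = 4; // K splits per [M,N] tile (illustrative)
    const int grid_size    = 5; // number of workgroups (illustrative)

    const int tile_count      = num_mn_tiles * k_batch;
    const int tiles_per_block = (tile_count + grid_size - 1) / grid_size;

    for(int block = 0; block < grid_size; ++block)
    {
        int tile_id        = block * tiles_per_block;
        const int tile_end = std::min(tile_id + tiles_per_block, tile_count);

        while(tile_id < tile_end)
        {
            const int k_idx        = tile_id % k_batch; // position inside the tile's K range
            const int next_k_tiles = std::min(k_batch - k_idx, tile_end - tile_id);
            const bool needs_sync  = (next_k_tiles != k_batch); // the commit's rule

            std::printf("block %d: [M,N] tile %d, k tiles %d..%d -> sync %s\n",
                        block, tile_id / k_batch, k_idx, k_idx + next_k_tiles - 1,
                        needs_sync ? "needed" : "skipped");
            tile_id += next_k_tiles;
        }
    }
    return 0;
}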