Commit 6c5111b7 authored by Harisankar Sadasivan
Browse files

code changes for streamk with reduction

parent ab3885aa
......@@ -147,8 +147,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
// Ceiling division of arg.K by k_grain (assuming integral index_t), scaled by
// KPerBlock. The result is passed to CalculateHasMainKBlockLoop to decide
// whether a main K-block loop is present.
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Atomic reduction strategy only (selected at compile time): enqueue an
// asynchronous zero-fill of arg.M * arg.N * sizeof(CDataType) bytes starting
// at arg.p_c_grid on stream_config.stream_id_.
// NOTE(review): presumably required because partial tiles are accumulated
// into C atomically -- confirm against the kernel.
// NOTE(review): the byte count covers exactly M * N elements; confirm this is
// correct when C has a padded leading dimension, and that arg.M * arg.N cannot
// overflow before the multiplication by sizeof.
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// NOTE(review): the string returned by hipGetErrorString is discarded, so a
// failing hipMemsetAsync is not reported from this call site.
hipGetErrorString(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
}
const auto Run = [&](const auto& kernel) {
dim3 grid_dim;
if(arg.Grid_size < 0)
......@@ -195,10 +199,31 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
else
{
// Compile-time selection of the launch path based on the tile map's
// reduction strategy.
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// Atomic strategy: plain timed launch, no per-launch preprocessing step.
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// Byte pointer into the workspace, offset past the region whose size is
// reported by get_workspace_size_for_acc(sizeof(GemmAccDataType)).
// NOTE(review): presumably the semaphore area is laid out directly after
// the accumulator area -- confirm against the workspace allocation.
char* workspace_semaphore =
reinterpret_cast<char*>(arg.p_workspace_) +
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
sizeof(GemmAccDataType));
// Callback that enqueues an asynchronous zero-fill of the semaphore region
// (get_workspace_size_for_semaphore() bytes) on stream_config.stream_id_.
// Captures by reference, so it must not outlive this scope.
auto preprocess = [&]() {
// NOTE(review): the string returned by hipGetErrorString is discarded, so
// a failing hipMemsetAsync is not reported from this call site.
hipGetErrorString(hipMemsetAsync(
workspace_semaphore,
0,
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config.stream_id_));
};
// NOTE(review): presumably the helper invokes preprocess before each kernel
// launch so every run starts with cleared semaphores -- confirm in
// launch_and_time_kernel_with_preprocess.
ave_time = launch_and_time_kernel_with_preprocess(
stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
}
};
constexpr index_t minimum_occupancy =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment