Commit 563c1e23 authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc

parents fb1bc1bf b14b1973
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
// kernel function Blockers:
// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe
template <typename GridwiseGemm,
typename BatchedGemmArg,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
// populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer
karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
});
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_ds_grid,
karg.p_c_grid + c_batch_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
typename BatchedGemmArg,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
// populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer
karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
});
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_ds_grid,
karg.p_c_grid + c_batch_offset,
p_shared_0,
p_shared_1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
: public DeviceBatchedGemmV2MultiD<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(BatchStrideA_) * g_idx;
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(BatchStrideB_) * g_idx;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, NumDTensor> ds_offset_;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
ds_offset_[i] = static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
});
return ds_offset_;
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(BatchStrideC_) * g_idx;
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
const std::array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideC_;
};
struct Argument : public GridwiseGemm::Argument
{
index_t Batch;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t BatchStrideA_,
index_t BatchStrideB_,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs_,
index_t BatchStrideE_,
index_t Batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: GridwiseGemm::Argument{p_a_grid_,
p_b_grid_,
p_ds_grid_,
p_e_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideDs_,
StrideE_,
1,
a_element_op_,
b_element_op_,
c_element_op_},
Batch{Batch_},
compute_ptr_offset_of_batch{
BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
{
}
};
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
std::array<std::size_t, NumDTensor> DsSize;
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) * arg.Batch;
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) * arg.Batch;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
Batch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor>& StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
Batch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceBatchedGemmXdlUniversal"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -131,6 +131,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
......@@ -147,26 +148,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
hipGetErrorString(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
hip_check_error(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
}
const auto Run = [&](const auto& kernel) {
dim3 grid_dim;
if(arg.Grid_size < 0)
{
int occupancy, num_cu;
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, 0);
hip_check_error(rtn);
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
num_cu = dev_prop.multiProcessorCount;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
arg.Grid_size = num_cu * occupancy;
grid_dim = arg.Grid_size;
}
......@@ -196,8 +198,31 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
else
{
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore =
reinterpret_cast<char*>(arg.p_workspace_) +
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
sizeof(GemmAccDataType));
auto preprocess = [&]() {
hipMemsetAsync(
workspace_semaphore,
0,
// sizeof(uint32_t),
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config.stream_id_);
};
ave_time = launch_and_time_kernel_with_preprocess(
stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
}
};
......@@ -211,14 +236,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
......@@ -340,53 +363,49 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
......@@ -396,14 +415,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
......@@ -418,6 +434,29 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
}
else
{
return 0;
}
}
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
......@@ -464,8 +503,205 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
CElementwiseOperation)
{
return Argument{
p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
int occupancy, num_cu;
const auto calculate_grid_size = [&](const auto& kernel) {
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
Grid_size = num_cu * occupancy;
};
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
calculate_grid_size(kernel);
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
calculate_grid_size(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
calculate_grid_size(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
calculate_grid_size(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
calculate_grid_size(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
calculate_grid_size(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
calculate_grid_size(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
calculate_grid_size(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
calculate_grid_size(kernel);
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
calculate_grid_size(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
calculate_grid_size(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
calculate_grid_size(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
calculate_grid_size(kernel);
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
calculate_grid_size(kernel);
}
}
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -381,10 +381,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
{
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
......@@ -750,6 +746,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
}
}
// check number of dimension, only implemented for 2D and 3D now
if(NDimSpatial != 2 && NDimSpatial != 3)
{
return false;
}
return true;
}
......
......@@ -18,7 +18,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
......@@ -78,17 +77,17 @@ template <typename ALayout,
// TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
enable_if_t<AK1 == BK1, bool> = false>
struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
: public DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
: public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage;
......@@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
index_t skipped_group_count_;
index_t grid_size_;
// Pointer to device memory with GEMM kernel arguments.
const void* p_dev_gemm_args_;
void* p_dev_gemm_kargs_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{})
{
......@@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_dev_gemm_args_ == nullptr)
if(arg.p_dev_gemm_kargs_ == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
......@@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config);
return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config);
}
float Run(const BaseArgument* p_arg,
......@@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_kargs,
void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
......@@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return LaunchKernel(gemm_kernel,
elementwise_kernel,
arg,
dev_gemm_args,
dev_gemm_kargs,
dev_gemm_workspace,
stream_config);
}
......@@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float LaunchKernel(const KernelFunction& gemm_kernel,
const KernelFunction2& elementwise_kernel,
const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_kargs,
[[maybe_unused]] void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
float time{0.f};
hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
......@@ -774,7 +780,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
cast_pointer_to_constant_address_space(dev_gemm_kargs),
arg.gemm_kernel_args_.size(),
arg.a_element_op_,
arg.b_element_op_,
......@@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return str.str();
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
arg.gemm_kernel_args_.data(),
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
......@@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
[[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch)
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
arg.UpdateKBatch(kbatch);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
};
......
......@@ -20,7 +20,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck {
......@@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
ComputeTypeA,
ComputeTypeB>;
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
......@@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return str.str();
}
void SetDeviceKernelArgs(Argument& arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const override
{
return SetDeviceKernelArgs(
*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
......
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};
......
......@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
// TODO: replace with GroupedGemmKernelArgument
struct GemmBiasTransKernelArg
{
// pointers
......@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->p_workspace_ = p_workspace;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_));
}
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
......@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(k_batch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(kbatch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
};
......
......@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return false;
}
if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
{
return false;
}
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
......@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
// TODO: deperecation notice.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};
......
......@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......@@ -38,7 +40,7 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -62,7 +64,13 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared_0,
p_shared_1,
karg,
karg.p_workspace_);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
{
}
......@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Atomic,
8,
4>
block_2_ctile_map_streamk;
};
struct SplitKBatchOffset
......@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
......@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
Problem& problem)
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
......@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
......@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
......@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
} // shuffle c and write-out end
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
}
} // while loop
} // for loop
}
template <bool HasMainKBlockLoop,
......@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
Problem& problem)
Problem& problem,
void* p_workspace)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
Block2CTileMap_streamk block_2_ctile_map_streamk(
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
......@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// start to compute
auto reduction_idx =
block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start;
if(threadIdx.x == 0)
{
p_semaphore[reduction_idx] = 0;
}
__syncthreads();
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
if(threadIdx.x == 0)
{
atomicAdd(&p_semaphore[reduction_idx], 1);
}
wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
continue;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
......@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared_0),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
{
......@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(0);
}
}
}
}
......
......@@ -13,6 +13,11 @@ namespace ck {
defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
......@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11__) || defined(__gfx12__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
......@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
......
......@@ -8,7 +8,6 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
......
......@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
......@@ -62,6 +63,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
......
......@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
......@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
......@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
......
......@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
......
......@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
// per-thread v_flag store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
// per-thread cmp store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
} // namespace ck_tile
......@@ -64,6 +64,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
......@@ -225,3 +226,7 @@
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
......@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest away
};
template <bf16_rounding_mode rounding =
......@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
return uint16_t(u.int32);
}
// TODO: do we need this on host?
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
struct
{
uint16_t lo;
uint16_t hi;
};
} u = {f};
const uint32_t low_nan = 0x7fff;
const uint32_t hi_nan = 0x7fff0000;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
// Note: in above code snipet, we use hi 16 bit
return u.hi;
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
......@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else
return float_to_bf16_truc_raw(f);
}
......
......@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, linear_offset, is_valid_element, x);
this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
this->template atomic_add<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
this->template atomic_max<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
auto tmp =
this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
this->template set<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x + tmp);
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update_raw(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
......@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check,
pre_nop>(
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......
......@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
......@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
......@@ -76,6 +100,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
......@@ -95,6 +121,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
......@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
......@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
......@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
}
else
{
// NOT implemented
static_assert(false, "The shuffle should always happen!");
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment