Unverified Commit 8c4897d1 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-756

parents 9ba9ebec 9e86ebd6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
struct GridwiseGemmPipeline_v3
{
__host__ __device__ static constexpr bool IsSupported(index_t)
{
// TODO: improve applicability
return true;
}
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
num_loop--;
while(num_loop > 0)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
num_loop--;
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
...@@ -164,8 +164,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -164,8 +164,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -318,8 +318,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -318,8 +318,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock = using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>; remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
......
...@@ -375,10 +375,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle ...@@ -375,10 +375,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>; remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
using DefaultBGridDesc_BK0_N_BK1 = using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( EGridDesc_M_N{}))>;
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap = using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1))>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ADataType, // FIXME: don't assume A/B have same datatype
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
AK1,
I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
BK1,
I1));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max(a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType),
c_block_size * sizeof(CShuffleDataType));
}
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
{
return math::integer_least_multiple(K, KPerBlock * K_Batch);
}
template <typename ALayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto MPad = CalculateMPadded(M);
const auto KPad = CalculateKPadded(K, KBatch);
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto AK0 = KPad / (KBatch * AK1);
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto NPad = CalculateNPadded(N);
const auto KPad = CalculateKPadded(K, KBatch);
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto BK0 = KPad / (KBatch * BK1);
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
// E desc for destination in blockwise copy
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
GemmSpecialization GemmSpec>
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch)
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
ignore = StrideDs;
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
#if 0
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
#endif
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename ELayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumDTensor_,
typename DsDataType_,
typename AGridDesc_KBatch_AK0_M_AK1,
typename BGridDesc_KBatch_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEElementwiseOperation_,
typename Block2ETileMap>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
uint32_t* barrier_count_finished,
const index_t KBatch,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation_& cde_element_op,
const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor_>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ComputeType,
decltype(a_grid_desc_kbatch_ak0_m_ak1),
decltype(a_block_desc_kbatch_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_kbatch_ak0_m_ak1,
make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_kbatch_ak0_m_ak1,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
ComputeType,
decltype(b_grid_desc_kbatch_bk0_n_bk1),
decltype(b_block_desc_kbatch_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_kbatch_bk0_n_bk1,
make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_kbatch_bk0_n_bk1,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
#if 1
if(block_work_idx[I0] == 0)
{
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
const index_t numNThreads = NPerBlock / nThreadSize;
const index_t numMThreads = BlockSize / numNThreads;
const index_t mThreadSize = MPerBlock / numMThreads;
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
EDataType,
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
true>
e_thread_zero_buf;
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
EDataType,
EDataType,
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, mThreadSize, 1, nThreadSize>,
Sequence<0, 1, 2, 3>,
3,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1],
m_tid * mThreadSize,
block_work_idx[I2],
n_tid * nThreadSize),
ck::tensor_operation::element_wise::PassThrough{}};
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_zero_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
__syncthreads();
if(threadIdx.x == 0)
{
atomicAdd(barrier_count_finished, 1);
}
}
#endif
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop =
__builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
a_block_desc_kbatch_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_kbatch_bk0_n_bk1,
b_block_desc_kbatch_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
if(threadIdx.x == 0)
{
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
}
__syncthreads();
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor_>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor_>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
},
Number<NumDTensor_>{}));
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation_,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor_,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
cde_element_op};
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor_, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
if(threadIdx.x == 0)
{
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
if(k_id_finished_t == KBatch)
{
*barrier_count_finished = 0;
}
}
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap>
__device__ static void Run(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
uint32_t* barrier_count_finished,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch,
const Block2ETileMap& block_2_etile_map)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
DsGridDesc_M_N ds_grid_desc_m_n;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
});
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
if(kbatch_id == KBatch - 1)
{
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
barrier_count_finished,
KBatch,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
else
{
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
barrier_count_finished,
KBatch,
a_element_op,
b_element_op,
ck::tensor_operation::element_wise::PassThrough{},
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
}
};
} // namespace ck
...@@ -138,8 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -138,8 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -308,8 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -308,8 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
......
...@@ -35,13 +35,17 @@ __global__ void ...@@ -35,13 +35,17 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop> template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem) typename GridwiseGemm::Problem problem)
{ {
...@@ -61,7 +65,8 @@ __global__ void ...@@ -61,7 +65,8 @@ __global__ void
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
...@@ -102,7 +107,8 @@ template <typename ALayout, ...@@ -102,7 +107,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Argument // Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem struct Argument : public tensor_operation::device::BaseArgument, public Problem
{ {
__host__ Argument(const FloatAB* p_a_grid_, __host__ Argument(const FloatA* p_a_grid_,
const FloatAB* p_b_grid_, const FloatB* p_b_grid_,
FloatC* p_c_grid_, FloatC* p_c_grid_,
index_t M_, index_t M_,
index_t N_, index_t N_,
...@@ -479,14 +485,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -479,14 +485,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
} }
const FloatAB* p_a_grid; const FloatA* p_a_grid;
const FloatAB* p_b_grid; const FloatB* p_b_grid;
FloatC* p_c_grid; FloatC* p_c_grid;
}; };
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned * sizeof(ComputeType) +
sizeof(FloatAB), b_block_space_size_aligned * sizeof(ComputeType)),
c_block_size * sizeof(FloatCShuffle)); c_block_size * sizeof(FloatCShuffle));
} }
...@@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const Problem& problem) const Problem& problem)
...@@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<AK0Number, MPerBlock, AK1Number>, Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatA,
FloatAB, ComputeType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<BK0Number, NPerBlock, BK1Number>, Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB,
FloatAB, ComputeType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, ComputeType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......
...@@ -173,8 +173,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -173,8 +173,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -345,8 +345,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -345,8 +345,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using C0GridDescriptor_NBlock_NPerBlock = using C0GridDescriptor_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>; remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
......
...@@ -330,8 +330,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -330,8 +330,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
return e_grid_desc_mblock_mperblock_nblock_nperblock; return e_grid_desc_mblock_mperblock_nblock_nperblock;
} }
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using DefaultBlock2ETileMap = using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
......
...@@ -259,8 +259,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -259,8 +259,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// denorm test fix, required to work around fp16 mfma issue // denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
template <typename GridwiseGemm>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
const typename GridwiseGemm::FloatAB* p_b_grid,
typename GridwiseGemm::FloatC* p_c_grid,
void* p_workspace,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_workspace,
M,
N,
K,
StrideA,
StrideB,
StrideC,
block_mapping,
static_cast<void*>(p_shared));
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_workspace;
ignore = M;
ignore = N;
ignore = K;
ignore = StrideA;
ignore = StrideB;
ignore = StrideC;
ignore = block_mapping;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
typename Block2CTileMap_,
typename FloatAB_,
typename FloatAcc_,
typename FloatC_,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto KPerBlock = K0PerBlock * K1;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using FloatAcc = FloatAcc_;
using FloatCShuffle = FloatAcc;
using Block2CTileMap = Block2CTileMap_;
using FloatAB = FloatAB_;
using FloatC = FloatC_;
struct Argument : public ck::tensor_operation::device::BaseArgument
{
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
Block2CTileMap block_mapping;
Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
uint32_t num_cu,
uint32_t occupancy,
uint32_t num_sk_blocks_)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
p_c_grid(p_c_grid_),
M(M_),
N(N_),
K(K_),
StrideA(StrideA_),
StrideB(StrideB_),
StrideC(StrideC_),
block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_)
{
}
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << std::endl;
}
};
__host__ __device__ static auto CalculateGridSize(const Argument& karg)
{
return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
math::integer_divide_ceil(karg.M, MPerBlock),
karg.k_batch);
}
__host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
__host__ __device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
{
const index_t K0 = CalculateK0(KPad);
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
__host__ __device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
{
const index_t K0 = CalculateK0(KPad);
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * K1, K1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * K1, K1, I1));
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
return false;
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
return false;
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
return false;
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
else
{
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
return has_main_k0_block_loop;
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
return transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
c_m_n_grid_desc, 8, KBatch);
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
return make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat / CShuffleMRepeatPerShuffle>{},
Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
Number<NRepeat / CShuffleNRepeatPerShuffle>{},
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
void* p_workspace,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
Block2CTileMap block_mapping,
void* __restrict__ p_shared_block)
{
uint32_t m = M;
uint32_t n = N;
uint32_t k = K;
uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;
uint32_t stride_a = StrideA;
uint32_t stride_b = StrideB;
uint32_t stride_c = StrideC;
const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const AElementwiseOperation a_element_op = AElementwiseOperation{};
const BElementwiseOperation b_element_op = BElementwiseOperation{};
const CElementwiseOperation c_element_op = CElementwiseOperation{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
uint32_t block_idx = block_mapping.get_block_idx();
bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
block_idx < block_mapping.reduction_start_block_idx;
bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
block_idx < block_mapping.dp_start_block_idx;
uint32_t iter_start, iter_end;
block_mapping.get_block_itr(block_idx, iter_start, iter_end);
uint32_t total_iter_length = iter_end - iter_start;
if(is_padding_block)
return;
uint32_t* p_semaphore =
reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters =
math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n = make_multi_index(
0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_n_reverse =
make_multi_index(0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n = make_multi_index(
0,
0,
0,
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_n_reverse =
make_multi_index(0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
CBlockTransferScalarPerVector_NWaveNPerXDL,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
CBlockTransferScalarPerVector_NWaveNPerXDL,
true>
acc_buf;
// start to compute
auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
FloatAcc, // SrcData,
FloatAcc, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CBlockTransferScalarPerVector_NWaveNPerXDL)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, // SrcData,
FloatC, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CBlockTransferScalarPerVector_NWaveNPerXDL),
CElementwiseOperation{}};
// block synchronization
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<FloatAcc*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
return;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
uint32_t tile_idx, iter_offset;
block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
const index_t k0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
const index_t num_k_block_main_loop = current_iter_length;
gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_k0_n_k1_grid_desc,
b_block_desc_k0_n_k1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// output: register to global memory
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<FloatCShuffle*>(p_shared_block),
c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(
make_tuple(CShuffleMRepeatPerShuffle,
M1,
M2,
M3,
M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
make_freeze_transform(I0), // freeze nblock
make_unmerge_transform(
make_tuple(CShuffleNRepeatPerShuffle,
N1,
N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 5, 6>{},
Sequence<>{},
Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatCShuffle,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// LDS to global
auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
0,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatCShuffle, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false,
// othre wise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false,
// othre wise has scratch
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
constexpr auto nxdlperwave_forward_step =
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
constexpr auto nxdlperwave_backward_step =
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
constexpr index_t nxdlperwave_value =
nxdlperwave_forward_sweep
? nxdlperwave_iter
: (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
// make sure it's safe to do ds_read
block_sync_lds();
c_block_copy_lds_to_global.SetSrcSliceOrigin(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(0, 0, 0, 0));
// LDS to global
if(is_dp_block)
c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
else if(is_sk_block)
{
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
else if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
c_block_copy_lds_to_global
.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
}
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
(nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_forward_step);
}
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_backward_step);
}
});
// move on mxdlperwave dimension
if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
mxdlperwave_forward_step);
}
});
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
}
template <typename Layout>
struct LStr
{
static std::string Get() { return ""; }
};
template <>
struct LStr<ck::tensor_layout::gemm::RowMajor>
{
static std::string Get() { return "R"; }
};
template <>
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
{
static std::string Get() { return "C"; }
};
static std::string GetTypeString()
{
auto str = std::stringstream();
// clang-format off
str << "GemmXdlStreamK_"
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< "_"
<< "B" << BlockSize << "_"
<< "Vec" << ABlockTransferSrcScalarPerVector << "x"
<< BBlockTransferSrcScalarPerVector << "x"
<< CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
<< MPerBlock << "x"
<< NPerBlock << "x"
<< K0PerBlock << "x"
<< K1 ;
// clang-format on
return str.str();
}
};
} // namespace ck
...@@ -27,6 +27,9 @@ template <typename GridwiseGemm, ...@@ -27,6 +27,9 @@ template <typename GridwiseGemm,
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif #endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
...@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop> ...@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif #endif
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
{ {
...@@ -241,8 +247,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -241,8 +247,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatC* p_c_grid; FloatC* p_c_grid;
}; };
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue // denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
......
...@@ -45,7 +45,8 @@ __global__ void ...@@ -45,7 +45,8 @@ __global__ void
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
typename ALayout, typename ALayout,
...@@ -85,7 +86,8 @@ template <index_t BlockSize, ...@@ -85,7 +86,8 @@ template <index_t BlockSize,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -108,13 +110,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -108,13 +110,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
struct Argument : public ck::tensor_operation::device::BaseArgument struct Argument : public ck::tensor_operation::device::BaseArgument
{ {
const FloatAB* p_a_grid; const FloatA* p_a_grid;
const FloatAB* p_b_grid; const FloatB* p_b_grid;
FloatC* p_c_grid; FloatC* p_c_grid;
index_t M; index_t M;
index_t N; index_t N;
...@@ -128,8 +130,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -128,8 +130,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t K0; index_t K0;
index_t k_batch; index_t k_batch;
Argument(const FloatAB* p_a_grid_, Argument(const FloatA* p_a_grid_,
const FloatAB* p_b_grid_, const FloatB* p_b_grid_,
FloatC* p_c_grid_, FloatC* p_c_grid_,
index_t M_, index_t M_,
index_t N_, index_t N_,
...@@ -365,7 +367,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -365,7 +367,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto c_block_size = constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType),
c_block_size * sizeof(FloatC)); c_block_size * sizeof(FloatC));
} }
...@@ -577,8 +579,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -577,8 +579,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
void* __restrict__ p_shared_block, void* __restrict__ p_shared_block,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const FloatAB* p_a_grid = karg.p_a_grid; const FloatA* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid; const FloatB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid; FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
...@@ -698,8 +700,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -698,8 +700,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
Sequence<1, K0PerBlock, MPerBlock, K1>, Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatA,
FloatAB, ComputeType,
decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc), decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -728,8 +730,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -728,8 +730,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
Sequence<1, K0PerBlock, NPerBlock, K1>, Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB,
FloatAB, ComputeType,
decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc), decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -759,7 +761,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -759,7 +761,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, ComputeType,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
...@@ -776,8 +778,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -776,8 +778,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block); ComputeType* p_a_block = static_cast<ComputeType*>(p_shared_block);
FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size; ComputeType* p_b_block = static_cast<ComputeType*>(p_shared_block) + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
...@@ -787,53 +789,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -787,53 +789,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
#if 0
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (karg.K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
...@@ -856,7 +811,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -856,7 +811,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
#endif
// output: register to global memory // output: register to global memory
{ {
......
...@@ -139,8 +139,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -139,8 +139,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -315,8 +315,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -315,8 +315,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
CGridDesc_M_N{}))>; CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
...@@ -634,10 +634,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -634,10 +634,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
FloatC, // typename DstData, FloatC, // typename DstData,
decltype( decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim, 5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
...@@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 ...@@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 ...@@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
CGridDesc_M_N{}))>; CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C0GridDesc_M_N{}))>; C0GridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
...@@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 ...@@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
FloatC, // typename Src0Data, FloatC, // typename Src0Data,
FloatC, // typename Src1Data, FloatC, // typename Src1Data,
FloatC, // typename DstData, FloatC, // typename DstData,
decltype( decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype( decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim, 5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
...@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
CGridDesc_M_N{}))>; CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C0GridDesc_M_N{}))>; C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C1GridDesc_M_N{}))>; C1GridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
...@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
FloatC, // typename Src1Data, FloatC, // typename Src1Data,
FloatC, // typename Src2Data, FloatC, // typename Src2Data,
FloatC, // typename DstData, FloatC, // typename DstData,
decltype( decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype( decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim, 5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple( using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{}))); Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
using ThreadwiseWolfordDescReduce = decltype( using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{}))); make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
using ThreadwiseWelford = using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>; ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
......
...@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st ...@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{}; static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static int __device__ static int
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
{ {
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
if(is_rightmost_block) if(is_rightmost_block)
{ {
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock; int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
int kPerThread = int kPerThread = kRightmostBlock < K_BlockTileSize
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); ? 0
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0) if(kPerBlockTail > 0)
{ {
...@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st ...@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
} }
else else
{ {
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); int kPerBlock = math::integer_divide_ceil(k, kGridSize);
return KThreadSliceSize * (kPerBlock / K_BlockTileSize); return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
} }
} }
...@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st ...@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford(); auto threadwise_welford = ThreadwiseWelford();
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
threadwise_welford.max_count_ = threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id); kRaw,
k_grid_size,
block_k_cluster_id,
thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f); mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
} // namespace ck
...@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths = constexpr auto ordered_src_access_lengths =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment