Commit 87a75734 authored by Jing Zhang

adding xdlops

parent 7972ab17
#ifndef CK_GRIDWISE_GROUP_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_GROUP_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t G,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
class GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPack,
class GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
class GemmABlockCopyThreadClusterArrangeOrder,
class GemmABlockCopySrcAccessOrder,
class GemmABlockCopyDstAccessOrder,
index_t GemmABlockCopySrcDataPerRead_GemmKPack,
index_t GemmABlockCopyDstDataPerWrite_GemmKPack,
class GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPack,
class GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
class GemmBBlockCopyThreadClusterArrangeOrder,
class GemmBBlockCopySrcAccessOrder,
class GemmBBlockCopyDstAccessOrder,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmKPack,
WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw
{
__device__ void Run(const ABFloat* const __restrict__ p_in_global,
const ABFloat* const __restrict__ p_wei_global,
CFloat* const __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_cpergroup_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_cpergroup_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_cpergroup_y_x_global_desc.GetLengths()[3];
constexpr index_t CPerGroup = C / G;
constexpr index_t KPerGroup = K / G;
static_assert(CPerGroup == wei_k_cpergroup_y_x_global_desc.GetLengths()[1], "wrong!");
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GemmG = G;
constexpr index_t GemmM = KPerGroup;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GemmKTotal = CPerGroup * Y * X;
static_assert(GemmKTotal % GemmKPack == 0,
"wrong! GemmKTotal should be multiple of GemmKPack");
constexpr index_t GemmK = GemmKTotal / GemmKPack;
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among block");
// construct tensor descriptor for group convolution
constexpr auto in_g_n_cpergroup_hi_wi_global_desc = make_native_tensor_descriptor(
Sequence<G, N, CPerGroup, Hi, Wi>{},
Sequence<CPerGroup * Hi * Wi, C * Hi * Wi, Hi * Wi, Wi, 1>{});
constexpr auto wei_g_kpergroup_cpergroup_y_x_global_desc =
make_native_tensor_descriptor_packed(Sequence<G, KPerGroup, CPerGroup, Y, X>{});
constexpr auto out_g_n_kpergroup_ho_wo_global_desc = make_native_tensor_descriptor(
Sequence<G, N, KPerGroup, Ho, Wo>{},
Sequence<KPerGroup * Ho * Wo, K * Ho * Wo, Ho * Wo, Wo, 1>{});
// input tensor
constexpr auto in_g_n_cpergroup_hip_wip_global_desc = transform_tensor_descriptor(
in_g_n_cpergroup_hi_wi_global_desc,
make_tuple(PassThrough<G>{},
PassThrough<N>{},
PassThrough<CPerGroup>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
constexpr index_t Hip = in_g_n_cpergroup_hip_wip_global_desc.GetLengths()[3];
constexpr index_t Wip = in_g_n_cpergroup_hip_wip_global_desc.GetLengths()[4];
constexpr auto in_g_n_cpergroup_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_g_n_cpergroup_hip_wip_global_desc,
make_tuple(PassThrough<G>{},
PassThrough<N>{},
PassThrough<CPerGroup>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
constexpr auto in_gemmg_gemmktotal_gemmn_global_desc = transform_tensor_descriptor(
in_g_n_cpergroup_y_ho_x_wo_global_desc,
make_tuple(PassThrough<G>{}, Merge<Sequence<CPerGroup, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<0>{}, Sequence<2, 3, 5>{}, Sequence<1, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto in_gemmg_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
in_gemmg_gemmktotal_gemmn_global_desc,
make_tuple(
PassThrough<GemmG>{}, UnMerge<Sequence<GemmK, GemmKPack>>{}, PassThrough<GemmN>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
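// Note: the UnMerge above splits the flattened GemmKTotal = CPerGroup*Y*X axis into
// (GemmK, GemmKPack), with GemmKPack as the fastest-varying sub-dimension, so each
// GEMM K step consumes GemmKPack consecutive reduction elements.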
// weight tensor
constexpr auto wei_gemmg_gemmm_gemmktotal_global_desc = unfold_tensor_descriptor(
wei_g_kpergroup_cpergroup_y_x_global_desc, Number<2>{}, Number<4>{});
constexpr auto wei_gemmg_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmg_gemmm_gemmktotal_global_desc,
make_tuple(
PassThrough<GemmG>{}, PassThrough<GemmM>{}, UnMerge<Sequence<GemmK, GemmKPack>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3>{}));
// output tensor
constexpr auto out_gemmg_gemmm_gemmn_global_desc = transform_tensor_descriptor(
out_g_n_kpergroup_ho_wo_global_desc,
make_tuple(PassThrough<G>{}, PassThrough<KPerGroup>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// gridwise batch-GEMM
constexpr auto gridwise_gemm = GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2<
GridSize,
BlockSize,
ABFloat,
AccFloat,
CFloat,
decltype(wei_gemmg_gemmk_gemmm_gemmkpack_global_desc),
decltype(in_gemmg_gemmk_gemmn_gemmkpack_global_desc),
decltype(out_gemmg_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterArrangeOrder,
GemmABlockCopySrcAccessOrder,
GemmABlockCopyDstAccessOrder,
3, // src vector read dimension of A matrix is GemmKPack
GemmABlockCopySrcDataPerRead_GemmKPack,
GemmABlockCopyDstDataPerWrite_GemmKPack,
GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterArrangeOrder,
GemmBBlockCopySrcAccessOrder,
GemmBBlockCopyDstAccessOrder,
2, // src vector read dimension of B matrix is GemmN
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
InMemoryDataOperation::Set,
WorkgroupSchdOrder>{};
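// In this GEMM, A is the weight tensor (GemmG x GemmK x GemmM x GemmKPack) and B is the
// input tensor (GemmG x GemmK x GemmN x GemmKPack), matching the order of the descriptor
// template arguments above, so Run() is called with (p_wei_global, p_in_global, p_out_global).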
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class Float,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA, // \todo unused parameter, remove
index_t GemmDataPerReadB // \todo unused parameter, remove
>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
#if CK_WORKAROUND_SWDEV_241664
static constexpr index_t MRepeats = (GemmMPerWave > 64) ? (GemmMPerWave / 64) : 1;
static constexpr index_t NRepeats = (GemmNPerWave > 64) ? (GemmNPerWave / 64) : 1;
static constexpr index_t MPerXdlops = (GemmMPerWave > 64) ? 64 : GemmMPerWave;
static constexpr index_t NPerXdlops = (GemmNPerWave > 64) ? 64 : GemmNPerWave;
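// e.g. GemmMPerWave = 128, GemmNPerWave = 64 gives MRepeats = 2, NRepeats = 1 and
// MPerXdlops = NPerXdlops = 64, i.e. each wave issues two 64x64 xdlops GEMMs along M
// (illustrative values, not fixed by this commit).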
static constexpr auto XdlopsGemm =
XdlopsGemm_t<Float, MPerXdlops, NPerXdlops, GemmDataPerReadA, GemmDataPerReadB>{};
#else
#if CK_USE_AMD_XDLOPS_INLINE_ASM
/// \todo add inline support for vector type c
static_assert(false, "Does not support inline asm for vector type c");
#else
static constexpr auto XdlopsGemm =
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
#endif
#endif
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
static constexpr index_t WaveSize = 64;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }
#if CK_WORKAROUND_SWDEV_241664
template <index_t MRepeats_ = MRepeats, index_t NRepeats_ = NRepeats>
__device__ constexpr auto CreateOutputVecZero() const;
template <>
__device__ constexpr auto CreateOutputVecZero<2, 1>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
template <>
__device__ constexpr auto CreateOutputVecZero<1, 2>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
template <>
__device__ constexpr auto CreateOutputVecZero<1, 1>() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#else
__device__ constexpr auto CreateOutputVecZero() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#endif
__device__ constexpr auto GetNumBlks() const
{
#if CK_WORKAROUND_SWDEV_241664
return XdlopsGemm.GetOutputLayout().GetNumBlks() * MRepeats * NRepeats;
#else
return XdlopsGemm.GetOutputLayout().GetNumBlks();
#endif
}
__device__ constexpr auto GetBlkSize() const
{
return XdlopsGemm.GetOutputLayout().GetBlkSize();
}
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
"BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
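// e.g. with BlockSize = 256 and GemmMWaves = GemmNWaves = 2, thread 130 belongs to
// wave 2, so waveId_m = 1, waveId_n = 0 and its A/B offsets start at 1*GemmMPerWave
// and 0*GemmNPerWave respectively (illustrative values only).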
}
#if CK_WORKAROUND_SWDEV_241664
template <index_t MRepeats_, index_t NRepeats_>
struct WithMNRepeats;
template <>
struct WithMNRepeats<2, 1>
{
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
p_c_thread.s.x.l =
XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.s.y.l);
return p_c_thread;
}
};
template <>
struct WithMNRepeats<1, 2>
{
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
p_c_thread.s.x.l =
XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
p_a_block, p_b_block + NPerXdlops, p_c_thread.s.y.l);
return p_c_thread;
}
};
template <>
struct WithMNRepeats<1, 1>
{
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
}
};
#endif
template <class FloatA, class FloatB, class FloatC>
__device__ FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
#if CK_WORKAROUND_SWDEV_241664
return WithMNRepeats<MRepeats, NRepeats>::template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
#else
return XdlopsGemm.template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
#endif
}
template <index_t AStride = GemmMPerWave, index_t BStride = GemmNPerWave>
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
#if CK_WORKAROUND_SWDEV_241664
const index_t xdlops_i = i / XdlopsGemm.GetOutputLayout().GetNumBlks();
const index_t j = i % XdlopsGemm.GetOutputLayout().GetNumBlks();
const index_t m = xdlops_i / NRepeats;
const index_t n = xdlops_i % NRepeats;
const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(j);
const index_t col =
(waveId % GemmNWaves) * BStride + n * NPerXdlops + thread_mtx_on_blk.col;
const index_t row =
(waveId / GemmNWaves) * AStride + m * MPerXdlops + thread_mtx_on_blk.row;
#else
const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(i);
const index_t col = (waveId % GemmNWaves) * BStride + thread_mtx_on_blk.col;
const index_t row = (waveId / GemmNWaves) * AStride + thread_mtx_on_blk.row;
#endif
return MatrixIndex{row, col};
}
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
const index_t total_reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
return make_ConstantMatrixDescriptor_packed(Number<total_reg_size>{}, Number<1>{});
}
__device__ void XdlopsMatrixCSetZero() const { XdlopsGemm.SetZeroXdlopsRegs(); }
template <class FloatC>
__device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
{
XdlopsGemm.ReadXdlopsRegs(p_c_thread);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_generic_tensor_slice_copy_v2.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
namespace ck {
enum WorkgroupScheduleOrder
{
MBlock1NBlock0,
NBlock1MBlock0
};
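// MBlock1NBlock0 orders the block work as Sequence<MBlockWork, NBlockWork> (so, with the
// packed cluster descriptor used below, the 1-D block id should sweep N-blocks fastest),
// while NBlock1MBlock0 swaps the two; the batched makers prepend the group count Gi in
// either case.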
template <index_t Gi,
index_t MBlockWork,
index_t NBlockWork,
WorkgroupScheduleOrder WorkgroupSchdOrder>
struct make_batch_block_work_sequence;
template <index_t Gi, index_t MBlockWork, index_t NBlockWork>
struct make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, MBlock1NBlock0>
{
__device__ constexpr auto get() { return Sequence<Gi, MBlockWork, NBlockWork>{}; }
};
template <index_t Gi, index_t MBlockWork, index_t NBlockWork>
struct make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, NBlock1MBlock0>
{
__device__ constexpr auto get() { return Sequence<Gi, NBlockWork, MBlockWork>{}; }
};
template <index_t MBlockWork, index_t NBlockWork, WorkgroupScheduleOrder WorkgroupSchdOrder>
struct make_block_work_sequence;
template <index_t MBlockWork, index_t NBlockWork>
struct make_block_work_sequence<MBlockWork, NBlockWork, MBlock1NBlock0>
{
__device__ constexpr auto get() { return Sequence<MBlockWork, NBlockWork>{}; }
};
template <index_t MBlockWork, index_t NBlockWork>
struct make_block_work_sequence<MBlockWork, NBlockWork, NBlock1MBlock0>
{
__device__ constexpr auto get() { return Sequence<NBlockWork, MBlockWork>{}; }
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_K_M_KPACK,
class ABlockCopyThreadClusterLengths_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_K_N_KPACK,
class BBlockCopyThreadClusterLengths_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation OutputMemOp,
WorkgroupScheduleOrder WorkgroupSchdOrder,
index_t ABlockCopySrcDataStride = 1,
index_t BBlockCopySrcDataStride = 1>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto K = b_k_n_kpack_global_desc.GetLengths()[0];
constexpr auto N = b_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto M = a_k_m_kpack_global_desc.GetLengths()[1];
constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_sequence =
make_block_work_sequence<MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[0] * MPerBlock)
: (block_work_id[1] * MPerBlock);
const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * NPerBlock)
: (block_work_id[0] * NPerBlock);
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
ABlockCopyDstDataPerWrite_KPACK,
KPACK * GemmDataPerReadM,
KPACK * GemmDataPerReadN);
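// e.g. KPACK = 4 with ABlockCopyDstDataPerWrite_KPACK = BBlockCopyDstDataPerWrite_KPACK = 4
// and GemmDataPerReadM = GemmDataPerReadN = 1 yields max_align = lcm(4, 4, 4, 4) = 4
// (illustrative values; the actual parameters come from the tuning configuration).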
// LDS
// be careful of LDS alignment
constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_k_m_kpack_global_desc),
decltype(a_k_m_kpack_block_desc),
decltype(a_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M_KPACK,
ABlockCopyThreadClusterLengths_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
ABlockCopySrcDataStride>({0, k_block_data_on_global, 0}, {0, 0, 0});
constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_k_n_kpack_global_desc),
decltype(b_k_n_kpack_block_desc),
decltype(b_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N_KPACK,
BBlockCopyThreadClusterLengths_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
BBlockCopySrcDataStride>({0, b_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block_double[2 * a_block_space];
__shared__ ABFloat p_b_block_double[2 * b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
ABFloat* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
ABFloat* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
ABFloat* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
ABFloat* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how fp16/bfloat16 datatypes are
// processed in gemm operation. fp16 type packs 4 fp16 values while
// bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// we recast datatype from a single fp16 to 4 packed fp16/2 packed bfloat16
// respectively.
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_now);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_now);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
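// e.g. K = 256, KPerBlock = 16: the main loop above covers 14 of the 16 K-chunks and this
// tail handles the remaining two; K % (2*KPerBlock) == 0 in that case, so the
// two-iteration branch is taken (illustrative values only).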
if(has_two_iteration_left) // if 2 iterations are left
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double + b_block_space);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
}
// copy output: register to global memory
{
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0();
constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});
using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
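// Each thread writes NumBlks blocks of BlkSize accumulators; its starting M index for
// block i (computed below via GetBeginOfThreadMatrixC) is decomposed as
// m = (k0*K1 + k1)*K2 + k2, mirroring the UnMerge<Sequence<K0, K1, K2>> applied to the
// output descriptor above.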
// force unrolling the output loop to get rid of scratch usage
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
OutputMemOp>({0, 0, 0, 0},
{k_thread_data_on_global / (K2 * K1),
k_thread_data_on_global % (K2 * K1) / K2,
k_thread_data_on_global % K2,
b_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation OutputMemOp,
WorkgroupScheduleOrder WorkgroupSchdOrder,
index_t ABlockCopySrcDataStride = 1,
index_t BBlockCopySrcDataStride = 1>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto Gi = b_g_k_n_kpack_global_desc.GetLengths()[0];
constexpr auto Go = c_g_m_n_global_desc.GetLengths()[0];
constexpr auto K = b_g_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto N = b_g_k_n_kpack_global_desc.GetLengths()[2];
constexpr auto M = a_g_k_m_kpack_global_desc.GetLengths()[2];
constexpr auto KPACK = b_g_k_n_kpack_global_desc.GetLengths()[3];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_sequence =
make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t group_id = block_work_id[0];
const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * MPerBlock)
: (block_work_id[2] * MPerBlock);
const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[2] * NPerBlock)
: (block_work_id[1] * NPerBlock);
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
ABlockCopyDstDataPerWrite_KPACK,
KPACK * GemmDataPerReadM,
KPACK * GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
decltype(a_g_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M_KPACK,
ABlockCopyThreadClusterLengths_G_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
3, // Dst dim to be written in vector form (KPACK dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
ABlockCopySrcDataStride>({group_id, 0, m_block_data_on_global, 0}, {0, 0, 0, 0});
constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_g_k_n_kpack_global_desc),
decltype(b_g_k_n_kpack_block_desc),
decltype(b_g_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N_KPACK,
BBlockCopyThreadClusterLengths_G_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
3, // Dst dim to be written in vector form (KPACK dimension)
BBlockCopySrcDataPerRead, // N dimension
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
BBlockCopySrcDataStride>({group_id, 0, n_block_data_on_global, 0}, {0, 0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block_double[2 * a_block_space];
__shared__ ABFloat p_b_block_double[2 * b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>;
using blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
ABFloat* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
ABFloat* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
ABFloat* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
ABFloat* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how fp16/bfloat16 datatypes are
// processed in gemm operation. fp16 type packs 4 fp16 values while
// bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// we recast datatype from a single fp16 to 4 packed fp16/2 packed bfloat16
// respectively.
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_now);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_now);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if 2 iterations are left
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double + b_block_space);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
}
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(PassThrough<Go>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
// force unrolling the output loop to get rid of scratch usage
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
OutputMemOp>(
{0, 0, 0, 0, 0},
{group_id,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation CGlobalMemoryOp,
WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto G = c_g_m_n_global_desc.GetLengths()[0];
constexpr auto M = c_g_m_n_global_desc.GetLengths()[1];
constexpr auto N = c_g_m_n_global_desc.GetLengths()[2];
constexpr auto K = b_g_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto KPack = b_g_k_n_kpack_global_desc.GetLengths()[3];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
constexpr index_t NWavePerBlock = NPerBlock / NPerWave;
constexpr auto block_work_sequence =
make_batch_block_work_sequence<G, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t g_block_data_on_global = block_work_id[0];
const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * MPerBlock)
: (block_work_id[2] * MPerBlock);
const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[2] * NPerBlock)
: (block_work_id[1] * NPerBlock);
constexpr index_t max_align = KPack;
// LDS be careful of LDS alignment
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
decltype(a_g_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M_KPACK,
ABlockCopyThreadClusterLengths_G_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
3, // Dst dim to be written in vector form (KPack dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, m_block_data_on_global, 0},
{0, 0, 0, 0});
constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, NPerBlock, KPack>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v5<
BlockSize,
decltype(b_g_k_n_kpack_global_desc),
decltype(b_g_k_n_kpack_block_desc),
decltype(b_g_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N_KPACK,
BBlockCopyThreadClusterLengths_G_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
3, // Dst dim to be written in vector form (KPack dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, n_block_data_on_global, 0},
{0, 0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWavePerBlock,
NWavePerBlock,
1,
1>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block[a_block_space];
__shared__ ABFloat p_b_block[b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block);
b_blockwise_copy.Run(p_b_global, p_b_block);
}
constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
// main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
// ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
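// (unlike the v4 copy used for A, the v5 blockwise copy used for B appears to keep its
// staging buffer internally, so the load/store calls below pass only the global source
// and LDS destination, respectively)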
// load next data from device mem
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
block_sync_lds();
// GEMM on current data
const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_a_block);
const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_b_block);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
block_sync_lds();
// store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
}
// tail
{
block_sync_lds();
// GEMM on last data
const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_a_block);
const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_b_block);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(
PassThrough<G>{}, UnMerge<Sequence<M / (M1 * M2), M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
// force unrolling the output loop to get rid of scratch usage
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
{0, 0, 0, 0, 0},
{g_block_data_on_global,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t BPerWave,
class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
class BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation CGlobalMemoryOp,
WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseBatchGemmXdlops_gkmkpack_gkn1bkpack_gmn_v2
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n1_b_kpack_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto G = c_g_m_n_global_desc.GetLengths()[0];
constexpr auto M = c_g_m_n_global_desc.GetLengths()[1];
constexpr auto N = c_g_m_n_global_desc.GetLengths()[2];
constexpr auto K = b_g_k_n1_b_kpack_global_desc.GetLengths()[1];
constexpr auto in_N1 = b_g_k_n1_b_kpack_global_desc.GetLengths()[2];
constexpr auto B = b_g_k_n1_b_kpack_global_desc.GetLengths()[3];
constexpr auto KPack = b_g_k_n1_b_kpack_global_desc.GetLengths()[4];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
constexpr index_t BWavePerBlock = in_N1;
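// Note: in this n1/b variant the wave count along the B axis is taken from the in_N1
// length of the B descriptor rather than from BPerBlock / BPerWave.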
static_assert((G * MBlockWork * BBlockWork) == GridSize, "Invalid GridSize");
constexpr auto block_work_sequence =
make_batch_block_work_sequence<G, MBlockWork, BBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t g_block_data_on_global = block_work_id[0];
const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * MPerBlock)
: (block_work_id[2] * MPerBlock);
const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[2] * BPerBlock)
: (block_work_id[1] * BPerBlock);
constexpr index_t max_align = KPack;
// LDS be careful of LDS alignment
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
decltype(a_g_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M_KPACK,
ABlockCopyThreadClusterLengths_G_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
3, // Dst dim to be written in vector form (KPack dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, m_block_data_on_global, 0},
{0, 0, 0, 0});
constexpr auto b_g_k_n1_b_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, in_N1, BPerBlock, KPack>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_g_k_n1_b_kpack_global_desc),
decltype(b_g_k_n1_b_kpack_block_desc),
decltype(b_g_k_n1_b_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
4, // Dst dim to be written in vector form (KPack dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, 0, b_block_data_on_global, 0},
{0, 0, 0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, BPerBlock * in_N1] is in LDS
// c_mtx[MPerBlock, BPerBlock * in_N1] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<BPerBlock * in_N1>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
BPerWave,
MWavePerBlock,
BWavePerBlock,
1,
1>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n1_b_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block[a_block_space];
__shared__ ABFloat p_b_block[b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block);
b_blockwise_copy.Run(p_b_global, p_b_block);
}
constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0, 0>{};
// main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// load next data from device mem
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
block_sync_lds();
// GEMM on current data
const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_a_block);
const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_b_block);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
block_sync_lds();
// store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block);
}
// tail
{
block_sync_lds();
// GEMM on last data
const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_a_block);
const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_b_block);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
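// Example (values taken from mfma_f32_32x32x1xf32 below): group_size = 4,
// num_input_blks = 2, num_groups_blk = 4, so M0 = 4, M1 = 2, M2 = 4, and each thread
// writes M0 * M2 = 16 accumulator values per output block (matching num_regs_blk).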
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(
PassThrough<G>{}, UnMerge<Sequence<M / (M1 * M2), M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
// force unrolling the output loop to get rid of scratch registers
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.template GetBeginOfThreadMatrixC<MPerWave, B>(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
{0, 0, 0, 0, 0},
{g_block_data_on_global,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
} // namespace ck
#endif
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
namespace ck {
enum struct mfma_instr
{
// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
mfma_f32_32x32x4bf16, // k reduction
mfma_f32_16x16x8bf16, // k reduction
};
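// Naming convention (informal): mfma_f32_<M>x<N>x<K><input type> accumulates into f32;
// the "k reduction" variants expose a single output block per wave, so the K dimension is
// split across the input blocks and reduced into one accumulator (see IsKReduction() below).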
template <mfma_instr instr>
struct mfma_info;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_32x32x2f32(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_16x16x4f32(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 1;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
};
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 1;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 8;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_32x32x8f16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 16;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_16x16x16f16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 4;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}
};
template <mfma_instr instr,
index_t MPerXdlops_,
index_t NPerXdlops_,
index_t MRepeats_,
index_t NRepeats_,
class OutputVecType_>
struct xdlops_info
{
static constexpr auto mfma_type = mfma_info<instr>{};
static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr index_t MRepeats = MRepeats_;
static constexpr index_t NRepeats = NRepeats_;
static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }
static constexpr bool IsKReduction()
{
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
}
static constexpr auto OutputVecType = OutputVecType_{};
};
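// For example, using the mfma_info tables above: mfma_f32_32x32x1xf32 has num_output_blks = 2,
// so IsKReduction() is false (and IsABroadcast() is true at 64x64 since NPerXdlops >= MPerXdlops),
// while mfma_f32_32x32x2xf32 has num_output_blks = 1 and num_input_blks = 2, making it a
// K-reduction instruction.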
template <class data_type,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct XdlopsGemm_t
{
struct MatrixIndex
{
index_t row;
index_t col;
};
__device__ static constexpr index_t GetNumBlksPerXdlops()
{
return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
}
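// e.g. MPerXdlops = NPerXdlops = 64 on a 32x32 instruction gives (64 * 64) / (32 * 32) = 4
// output blocks per xdlops.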
__device__ constexpr XdlopsGemm_t()
{
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
static_assert(mfma_type.k % mfma_type.k_base == 0, "k and k_base are inconsistent!");
}
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / mfma_type.wave_size;
}
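// e.g. a 64x64 xdlops tile on a 64-lane wave yields 64 * 64 / 64 = 64 accumulator values
// per lane.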
#if CK_USE_AMD_XDLOPS_EMULATE
// emulate xdlops
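// Scalar reference path: reproduces the per-lane accumulator layout of the hardware MFMA
// instructions with plain inner products, so results can be checked on targets without
// xdlops support.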
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ FloatC XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC p_c_thread) const
{
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
// K reduction
static_if<IsKReduction>{}([&](auto) {
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = (k + n) * M;
index_t b_off = (k + n) * N;
index_t c_off = 0;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}).Else([&](auto) {
static_if<IsABroadcast>{}([&](auto) {
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
{
for(index_t n_i = 0; n_i < NRepeats; ++n_i)
{
// ABroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
index_t b_off =
k * N + n * mfma_type.num_threads_blk + NPerXdlops * n_i;
index_t c_off = n * mfma_type.num_regs_blk +
b * mfma_type.num_regs_xdlops +
(NRepeats * m_i + n_i) * GetRegSizePerXdlops();
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size +
blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] +=
inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
}
}
}).Else([&](auto) {
// BBroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + n * mfma_type.m;
index_t b_off = k * N + b * mfma_type.n;
index_t c_off =
n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
});
});
return p_c_thread;
}
#endif
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ FloatC Run(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC p_c_thread) const
{
static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");
static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
is_same<data_type, ushort>::value,
"base data_type must be float, half, ushort!");
#if CK_USE_AMD_XDLOPS_EMULATE
p_c_thread = XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
#else
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
FloatA a[K * MRepeats];
FloatB b[K * NRepeats];
static_assert(sizeof(FloatA) % (sizeof(data_type) * mfma_type.k_base) == 0,
"wrong! sizeof(FloatA) must be a multiple of sizeof(data_type) * k_base");
static_assert(!IsKReduction || K % mfma_type.num_input_blks == 0,
"K must be divisible by mfma_type.num_input_blks!");
static_assert(!IsKReduction || (MRepeats == 1 && NRepeats == 1),
"KReduction does not support M/N Repeats!");
constexpr index_t KRepeats = sizeof(FloatA) / (sizeof(data_type) * mfma_type.k_base);
auto pa = reinterpret_cast<const data_type*>(&a);
auto pb = reinterpret_cast<const data_type*>(&b);
constexpr index_t AStride = K * KRepeats;
constexpr index_t BStride = K * KRepeats;
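// Worked example (FloatA/FloatB are the KPack-wide LDS vector types fed in by the blockwise
// GEMM): for fp16 with KPack = 8, sizeof(FloatA) = 16 bytes; mfma_f32_32x32x8f16 has
// k_base = 4, so KRepeats = 16 / (2 * 4) = 2 and each a[]/b[] element feeds two mfma issues.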
static_if<!IsKReduction>{}([&](auto) {
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
for(index_t k_i = 0; k_i < K; ++k_i)
a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
for(index_t n_i = 0; n_i < NRepeats; ++n_i)
for(index_t k_i = 0; k_i < K; ++k_i)
b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
#if CK_WORKAROUND_SWDEV_229564
#pragma unroll
#endif
for(index_t k_i = 0; k_i < K * KRepeats; ++k_i)
{
p_c_thread = mfma_type.template run<MPerXdlops * MRepeats,
NPerXdlops * NRepeats,
AStride,
BStride>(
&pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
}
}).Else([&](auto) {
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
// load into registers
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{
a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
}
#if CK_WORKAROUND_SWDEV_229564
#pragma unroll
#endif
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{
for(index_t i = 0; i < KRepeats; ++i)
p_c_thread = mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
&pa[(k_i * KRepeats + i) * mfma_type.k_base],
&pb[(k_i * KRepeats + i) * mfma_type.k_base],
p_c_thread);
}
});
#endif
return p_c_thread;
}
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
{
const index_t xdlops_i = i / GetNumBlksPerXdlops();
const index_t j = i % GetNumBlksPerXdlops();
const index_t m_i = xdlops_i / NRepeats;
const index_t n_i = xdlops_i % NRepeats;
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t col_blk = j % mfma_type.num_output_blks;
index_t row_blk = j / mfma_type.num_output_blks;
static_if<!IsABroadcast>{}([&](auto) {
col_blk = j / mfma_type.num_output_blks;
row_blk = j % mfma_type.num_output_blks;
});
index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
return MatrixIndex{row, col};
}
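// Worked example (mfma_f32_32x32x1xf32 at 64x64, i.e. ABroadcast with GetNumBlksPerXdlops() == 4):
// lane 37 gives blk_id = 1, blk_td = 5; block i = 1 gives col_blk = 1, row_blk = 0, hence
// col = 1 * 32 + 5 = 37 and row = 1 * group_size = 4.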
__device__ void SetZeroXdlopsRegs() const {}
template <class FloatC>
__device__ void ReadXdlopsRegs(FloatC* const __restrict__) const
{
}
template <class data_type_ = data_type,
index_t MPerWave_ = GemmMPerWave,
index_t NPerWave_ = GemmNPerWave>
static constexpr auto GetXdlopsInfo();
template <>
static constexpr auto GetXdlopsInfo<float, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, c_vec32_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1, c_vec32_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
struct OutputLayout
{
__device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__device__ static constexpr index_t M0() { return mfma_type.group_size; }
__device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
__device__ static constexpr index_t GetNumBlks()
{
return GetNumBlksPerXdlops() * MRepeats * NRepeats;
}
__device__ static constexpr auto CreateOutputVecZero()
{
return GetXdlopsInfo().OutputVecType.CreateVecZero();
}
};
__device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
};
} // namespace ck
#endif
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "float_type.hpp"
namespace ck {
// A, B, C, cbsz, abid, blgp
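// Roughly: cbsz sets the broadcast group size for the A operand, abid selects which A block
// is broadcast within that group, and blgp permutes/broadcasts the B lanes; see the AMD
// MFMA/CDNA ISA documentation for the exact encodings.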
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x1f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<16, 64>(const float* reg_a,
const float* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<64, 16>(const float* reg_a,
const float* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const float* reg_a, const float* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x8f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x4f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<16, 64>(const half4_t* reg_a,
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<64, 16>(const half4_t* reg_a,
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;
template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
} // namespace ck
#endif
#include "common_header.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
// read params: problem description
constexpr index_t G = CK_PARAM_PROBLEM_G;
constexpr index_t N = CK_PARAM_PROBLEM_N;
constexpr index_t K = CK_PARAM_PROBLEM_K;
constexpr index_t C = CK_PARAM_PROBLEM_C;
constexpr index_t Hi = CK_PARAM_PROBLEM_HI;
constexpr index_t Wi = CK_PARAM_PROBLEM_WI;
constexpr index_t Ho = CK_PARAM_PROBLEM_HO;
constexpr index_t Wo = CK_PARAM_PROBLEM_WO;
constexpr index_t Y = CK_PARAM_PROBLEM_Y;
constexpr index_t X = CK_PARAM_PROBLEM_X;
constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;
constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;
constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;
constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;
constexpr auto CPerGroup = C / G;
constexpr auto in_n_c_hi_wi_desc =
make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
constexpr auto wei_k_cpergroup_y_x_desc =
make_native_tensor_descriptor_packed(Sequence<K, CPerGroup, Y, X>{});
constexpr auto out_n_k_ho_wo_desc =
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
// local Sequence aliases (suffixed so they do not redeclare the function's template
// parameters of the same name)
using ConvStridesSeq = Sequence<ConvStrideH, ConvStrideW>;
using ConvDilationsSeq = Sequence<ConvDilationH, ConvDilationW>;
using InLeftPadsSeq = Sequence<InLeftPadH, InLeftPadW>;
using InRightPadsSeq = Sequence<InRightPadH, InRightPadW>;
// read params: tuning parameters
constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
constexpr index_t GemmMPerWave = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE;
constexpr index_t GemmNPerWave = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE;
constexpr index_t GemmKPack = CK_PARAM_TUNABLE_GEMM_KPACK;
// read params: dependent parameters
constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE;
constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE;
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmABlockCopyClusterLengths_GemmM =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM =
GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmKPack =
GemmKPack / GemmABlockCopyClusterLengths_GemmKPack;
using GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack =
Sequence<1,
GemmABlockCopyClusterLengths_GemmK,
GemmABlockCopyClusterLengths_GemmM,
GemmABlockCopyClusterLengths_GemmKPack>;
using GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack =
Sequence<1,
GemmABlockCopyThreadSliceLengths_GemmK,
GemmABlockCopyThreadSliceLengths_GemmM,
GemmABlockCopyThreadSliceLengths_GemmKPack>;
using GemmABlockCopyThreadClusterArrangeOrder =
Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN =
GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmKPack =
GemmKPack / GemmBBlockCopyClusterLengths_GemmKPack;
using GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack =
Sequence<1,
GemmBBlockCopyClusterLengths_GemmK,
GemmBBlockCopyClusterLengths_GemmN,
GemmBBlockCopyClusterLengths_GemmKPack>;
using GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack =
Sequence<1,
GemmBBlockCopyThreadSliceLengths_GemmK,
GemmBBlockCopyThreadSliceLengths_GemmN,
GemmBBlockCopyThreadSliceLengths_GemmKPack>;
using GemmBBlockCopyThreadClusterArrangeOrder =
Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
// gridwise GEMM
constexpr auto wkgrp_schd_order = NBlock1MBlock0;
constexpr auto gridwise_conv =
GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
GridSize,
BlockSize,
FLOAT, // Input data type
FLOAT_ACCUM, // Acc data type
FLOAT, // Output data type
decltype(in_n_c_hi_wi_desc),
decltype(wei_k_cpergroup_y_x_desc),
decltype(out_n_k_ho_wo_desc),
G,
ConvStridesSeq,
ConvDilationsSeq,
InLeftPadsSeq,
InRightPadsSeq,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterArrangeOrder,
GemmABlockCopySrcAccessOrder,
GemmABlockCopyDstAccessOrder,
GemmABlockCopySrcDataPerRead_GemmKPack,
GemmABlockCopyDstDataPerWrite_GemmKPack,
GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterArrangeOrder,
GemmBBlockCopySrcAccessOrder,
GemmBBlockCopyDstAccessOrder,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
wkgrp_schd_order>{};
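// Note: p_in_global, p_wei_global and p_out_global are assumed to be device pointers backing
// in_nchw, wei_kcyx and out_nkhw; how they are allocated and filled is outside this wrapper.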
gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
}
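// Purely illustrative configuration (hypothetical values, chosen only to show which
// CK_PARAM_* macros the wrapper expects; real values come from the host code that compiles
// this kernel and must satisfy the divisibility checks inside the gridwise kernel):
//   -DCK_PARAM_PROBLEM_G=1 -DCK_PARAM_PROBLEM_N=128 -DCK_PARAM_PROBLEM_K=256
//   -DCK_PARAM_PROBLEM_C=256 -DCK_PARAM_PROBLEM_HI=28 -DCK_PARAM_PROBLEM_WI=28
//   -DCK_PARAM_PROBLEM_HO=28 -DCK_PARAM_PROBLEM_WO=28 -DCK_PARAM_PROBLEM_Y=3 -DCK_PARAM_PROBLEM_X=3
//   -DCK_PARAM_TUNABLE_GEMM_M_PER_BLOCK=128 -DCK_PARAM_TUNABLE_GEMM_N_PER_BLOCK=128
//   -DCK_PARAM_TUNABLE_GEMM_K_PER_BLOCK=8 -DCK_PARAM_TUNABLE_GEMM_M_PER_WAVE=64
//   -DCK_PARAM_TUNABLE_GEMM_N_PER_WAVE=64 -DCK_PARAM_TUNABLE_GEMM_KPACK=8 ...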