Commit 87a75734 authored by Jing Zhang

adding xdlops

parent 7972ab17
#ifndef CK_GRIDWISE_GROUP_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_GROUP_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t G,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
class GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPack,
class GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
class GemmABlockCopyThreadClusterArrangeOrder,
class GemmABlockCopySrcAccessOrder,
class GemmABlockCopyDstAccessOrder,
index_t GemmABlockCopySrcDataPerRead_GemmKPack,
index_t GemmABlockCopyDstDataPerWrite_GemmKPack,
class GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPack,
class GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
class GemmBBlockCopyThreadClusterArrangeOrder,
class GemmBBlockCopySrcAccessOrder,
class GemmBBlockCopyDstAccessOrder,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmKPack,
WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw
{
__device__ void Run(const ABFloat* const __restrict__ p_in_global,
const ABFloat* const __restrict__ p_wei_global,
CFloat* const __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_cpergroup_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_cpergroup_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_cpergroup_y_x_global_desc.GetLengths()[3];
constexpr index_t CPerGroup = C / G;
constexpr index_t KPerGroup = K / G;
static_assert(CPerGroup == wei_k_cpergroup_y_x_global_desc.GetLengths()[1], "wrong!");
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
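// Implicit GEMM mapping for grouped forward convolution (one GEMM per group):
//   GemmG = G                     (batch dimension of the batched GEMM)
//   GemmM = K / G                 (output channels per group)
//   GemmN = N * Ho * Wo           (output pixels)
//   GemmKTotal = (C / G) * Y * X  (reduction dimension, split below into GemmK * GemmKPack)
// Illustrative example (not from this commit): N=1, C=64, K=64, G=1, Y=X=3, Ho=Wo=28,
// GemmKPack=4 gives GemmM=64, GemmN=784, GemmKTotal=576, GemmK=144.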
constexpr index_t GemmG = G;
constexpr index_t GemmM = KPerGroup;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GemmKTotal = CPerGroup * Y * X;
static_assert(GemmKTotal % GemmKPack == 0,
"wrong! GemmKTotal should be a multiple of GemmKPack");
constexpr index_t GemmK = GemmKTotal / GemmKPack;
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among blocks");
// construct tensor descriptor for group convolution
constexpr auto in_g_n_cpergroup_hi_wi_global_desc = make_native_tensor_descriptor(
Sequence<G, N, CPerGroup, Hi, Wi>{},
Sequence<CPerGroup * Hi * Wi, C * Hi * Wi, Hi * Wi, Wi, 1>{});
constexpr auto wei_g_kpergroup_cpergroup_y_x_global_desc =
make_native_tensor_descriptor_packed(Sequence<G, KPerGroup, CPerGroup, Y, X>{});
constexpr auto out_g_n_kpergroup_ho_wo_global_desc = make_native_tensor_descriptor(
Sequence<G, N, KPerGroup, Ho, Wo>{},
Sequence<KPerGroup * Ho * Wo, K * Ho * Wo, Ho * Wo, Wo, 1>{});
// input tensor
constexpr auto in_g_n_cpergroup_hip_wip_global_desc = transform_tensor_descriptor(
in_g_n_cpergroup_hi_wi_global_desc,
make_tuple(PassThrough<G>{},
PassThrough<N>{},
PassThrough<CPerGroup>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
constexpr index_t Hip = in_g_n_cpergroup_hip_wip_global_desc.GetLengths()[3];
constexpr index_t Wip = in_g_n_cpergroup_hip_wip_global_desc.GetLengths()[4];
constexpr auto in_g_n_cpergroup_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_g_n_cpergroup_hip_wip_global_desc,
make_tuple(PassThrough<G>{},
PassThrough<N>{},
PassThrough<CPerGroup>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
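// Embed expands each padded spatial dimension into a (filter, output) pair:
//   hip = y * ConvDilationH + ho * ConvStrideH, and wip = x * ConvDilationW + wo * ConvStrideW,
// so every (y, ho) / (x, wo) combination addresses the input element it multiplies.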
constexpr auto in_gemmg_gemmktotal_gemmn_global_desc = transform_tensor_descriptor(
in_g_n_cpergroup_y_ho_x_wo_global_desc,
make_tuple(PassThrough<G>{}, Merge<Sequence<CPerGroup, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<0>{}, Sequence<2, 3, 5>{}, Sequence<1, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto in_gemmg_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
in_gemmg_gemmktotal_gemmn_global_desc,
make_tuple(
PassThrough<GemmG>{}, UnMerge<Sequence<GemmK, GemmKPack>>{}, PassThrough<GemmN>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// weight tensor
constexpr auto wei_gemmg_gemmm_gemmktotal_global_desc = unfold_tensor_descriptor(
wei_g_kpergroup_cpergroup_y_x_global_desc, Number<2>{}, Number<4>{});
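// unfold_tensor_descriptor collapses dimensions 2..4 (CPerGroup, Y, X) of the weight
// descriptor into a single GemmKTotal dimension, giving a [GemmG, GemmM, GemmKTotal] view.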
constexpr auto wei_gemmg_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmg_gemmm_gemmktotal_global_desc,
make_tuple(
PassThrough<GemmG>{}, PassThrough<GemmM>{}, UnMerge<Sequence<GemmK, GemmKPack>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3>{}));
// output tensor
constexpr auto out_gemmg_gemmm_gemmn_global_desc = transform_tensor_descriptor(
out_g_n_kpergroup_ho_wo_global_desc,
make_tuple(PassThrough<G>{}, PassThrough<KPerGroup>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// gridwise batch-GEMM
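// For each group g the batched GEMM computes
//   out[g, m, n] = sum_{k, kpack} wei[g, k, m, kpack] * in[g, k, n, kpack],
// with A = weight, B = input, C = output.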
constexpr auto gridwise_gemm = GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2<
GridSize,
BlockSize,
ABFloat,
AccFloat,
CFloat,
decltype(wei_gemmg_gemmk_gemmm_gemmkpack_global_desc),
decltype(in_gemmg_gemmk_gemmn_gemmkpack_global_desc),
decltype(out_gemmg_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterArrangeOrder,
GemmABlockCopySrcAccessOrder,
GemmABlockCopyDstAccessOrder,
3, // src vector read dimension of A matrix is GemmKPack
GemmABlockCopySrcDataPerRead_GemmKPack,
GemmABlockCopyDstDataPerWrite_GemmKPack,
GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterArrangeOrder,
GemmBBlockCopySrcAccessOrder,
GemmBBlockCopyDstAccessOrder,
2, // src vector read dimension of B matrix is GemmN
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
InMemoryDataOperation::Set,
WorkgroupSchdOrder>{};
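// Note the argument order: the weight tensor is the A matrix and the input tensor is
// the B matrix of the GEMM, so p_wei_global is passed first.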
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class Float,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA, // \todo unused parameter, remove
index_t GemmDataPerReadB // \todo unused parameter, remove
>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
#if CK_WORKAROUND_SWDEV_241664
static constexpr index_t MRepeats = (GemmMPerWave > 64) ? (GemmMPerWave / 64) : 1;
static constexpr index_t NRepeats = (GemmNPerWave > 64) ? (GemmNPerWave / 64) : 1;
static constexpr index_t MPerXdlops = (GemmMPerWave > 64) ? 64 : GemmMPerWave;
static constexpr index_t NPerXdlops = (GemmNPerWave > 64) ? 64 : GemmNPerWave;
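// Workaround: one xdlops instruction covers at most a 64-wide tile per wave, so
// GemmMPerWave/GemmNPerWave larger than 64 are realized by repeating 64-wide xdlops
// MRepeats/NRepeats times (e.g. GemmMPerWave = 128 -> MRepeats = 2, MPerXdlops = 64).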
static constexpr auto XdlopsGemm =
XdlopsGemm_t<Float, MPerXdlops, NPerXdlops, GemmDataPerReadA, GemmDataPerReadB>{};
#else
#if CK_USE_AMD_XDLOPS_INLINE_ASM
/// \todo add inline asm support for vector type c
static_assert(false, "Inline asm is not supported for vector type c");
#else
static constexpr auto XdlopsGemm =
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
#endif
#endif
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
static constexpr index_t WaveSize = 64;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }
#if CK_WORKAROUND_SWDEV_241664
template <index_t MRepeats_ = MRepeats, index_t NRepeats_ = NRepeats>
__device__ constexpr auto CreateOutputVecZero() const;
template <>
__device__ constexpr auto CreateOutputVecZero<2, 1>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
template <>
__device__ constexpr auto CreateOutputVecZero<1, 2>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
template <>
__device__ constexpr auto CreateOutputVecZero<1, 1>() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#else
__device__ constexpr auto CreateOutputVecZero() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#endif
__device__ constexpr auto GetNumBlks() const
{
#if CK_WORKAROUND_SWDEV_241664
return XdlopsGemm.GetOutputLayout().GetNumBlks() * MRepeats * NRepeats;
#else
return XdlopsGemm.GetOutputLayout().GetNumBlks();
#endif
}
__device__ constexpr auto GetBlkSize() const
{
return XdlopsGemm.GetOutputLayout().GetBlkSize();
}
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
"BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
}
#if CK_WORKAROUND_SWDEV_241664
template <index_t MRepeats_, index_t NRepeats_>
struct WithMNRepeats;
template <>
struct WithMNRepeats<2, 1>
{
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
p_c_thread.s.x.l =
XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.s.y.l);
return p_c_thread;
}
};
template <>
struct WithMNRepeats<1, 2>
{
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
p_c_thread.s.x.l =
XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
p_a_block, p_b_block + NPerXdlops, p_c_thread.s.y.l);
return p_c_thread;
}
};
template <>
struct WithMNRepeats<1, 1>
{
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
}
};
#endif
template <class FloatA, class FloatB, class FloatC>
__device__ FloatC Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
#if CK_WORKAROUND_SWDEV_241664
return WithMNRepeats<MRepeats, NRepeats>::template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
#else
return XdlopsGemm.template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
#endif
}
template <index_t AStride = GemmMPerWave, index_t BStride = GemmNPerWave>
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
#if CK_WORKAROUND_SWDEV_241664
const index_t xdlops_i = i / XdlopsGemm.GetOutputLayout().GetNumBlks();
const index_t j = i % XdlopsGemm.GetOutputLayout().GetNumBlks();
const index_t m = xdlops_i / NRepeats;
const index_t n = xdlops_i % NRepeats;
const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(j);
const index_t col =
(waveId % GemmNWaves) * BStride + n * NPerXdlops + thread_mtx_on_blk.col;
const index_t row =
(waveId / GemmNWaves) * AStride + m * MPerXdlops + thread_mtx_on_blk.row;
#else
const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(i);
const index_t col = (waveId % GemmNWaves) * BStride + thread_mtx_on_blk.col;
const index_t row = (waveId / GemmNWaves) * AStride + thread_mtx_on_blk.row;
#endif
return MatrixIndex{row, col};
}
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
const index_t total_reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
return make_ConstantMatrixDescriptor_packed(Number<total_reg_size>{}, Number<1>{});
}
__device__ void XdlopsMatrixCSetZero() const { XdlopsGemm.SetZeroXdlopsRegs(); }
template <class FloatC>
__device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
{
XdlopsGemm.ReadXdlopsRegs(p_c_thread);
}
};
} // namespace ck
#endif
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "float_type.hpp"
namespace ck {
// A, B, C, cbsz, abid, blgp
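// The trailing three int arguments are the MFMA modifiers of the AMD MI-series ISA:
//   cbsz - control broadcast size, abid - A-matrix broadcast block id,
//   blgp - B-matrix lane group pattern;
// together they select how the A/B operands are broadcast across the 64-lane wave.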
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32;
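// Each (MPerWave, NPerWave) specialization below issues the 32x32x1 MFMA with a different
// broadcast pattern; AStride/BStride are the element offsets to the second A/B operand
// used when the wave's tile spans more than one 64-wide block.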
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x1f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<16, 64>(const float* reg_a,
const float* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<64, 16>(const float* reg_a,
const float* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const float* reg_a, const float* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x8f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x4f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<16, 64>(const half4_t* reg_a,
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<64, 16>(const half4_t* reg_a,
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;
template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
} // namespace ck
#endif
#include "common_header.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
extern "C" __global__ void
gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const FLOAT* const __restrict__ p_in_global,
const FLOAT* const __restrict__ p_wei_global,
FLOAT* const __restrict__ p_out_global)
{
using namespace ck;
// read params: problem description
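// All CK_PARAM_* values below are expected to be supplied as -D compile-time definitions
// when this kernel source is built.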
constexpr index_t G = CK_PARAM_PROBLEM_G;
constexpr index_t N = CK_PARAM_PROBLEM_N;
constexpr index_t K = CK_PARAM_PROBLEM_K;
constexpr index_t C = CK_PARAM_PROBLEM_C;
constexpr index_t Hi = CK_PARAM_PROBLEM_HI;
constexpr index_t Wi = CK_PARAM_PROBLEM_WI;
constexpr index_t Ho = CK_PARAM_PROBLEM_HO;
constexpr index_t Wo = CK_PARAM_PROBLEM_WO;
constexpr index_t Y = CK_PARAM_PROBLEM_Y;
constexpr index_t X = CK_PARAM_PROBLEM_X;
constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;
constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;
constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;
constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;
constexpr auto CPerGroup = C / G;
constexpr auto in_n_c_hi_wi_desc =
make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
constexpr auto wei_k_cpergroup_y_x_desc =
make_native_tensor_descriptor_packed(Sequence<K, CPerGroup, Y, X>{});
constexpr auto out_n_k_ho_wo_desc =
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
using ConvStrides = Sequence<ConvStrideH, ConvStrideW>;
using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;
using InLeftPads = Sequence<InLeftPadH, InLeftPadW>;
using InRightPads = Sequence<InRightPadH, InRightPadW>;
// read params: tuning parameters
constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
constexpr index_t GemmMPerWave = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE;
constexpr index_t GemmNPerWave = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE;
constexpr index_t GemmKPack = CK_PARAM_TUNABLE_GEMM_KPACK;
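// GemmKPack groups KPack consecutive reduction elements so each thread can feed a full
// xdlops input vector (e.g. 4 half values per mfma instruction); the valid width depends
// on the data type.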
// read params: dependent parameters
constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE;
constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE;
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmABlockCopyClusterLengths_GemmM =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM =
GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmKPack =
GemmKPack / GemmABlockCopyClusterLengths_GemmKPack;
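// Each thread copies ThreadSliceLengths elements per dimension; ClusterLengths threads
// cooperate so that cluster * slice covers the whole per-block A tile
// (1 x GemmKPerBlock x GemmMPerBlock x GemmKPack).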
using GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack =
Sequence<1,
GemmABlockCopyClusterLengths_GemmK,
GemmABlockCopyClusterLengths_GemmM,
GemmABlockCopyClusterLengths_GemmKPack>;
using GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack =
Sequence<1,
GemmABlockCopyThreadSliceLengths_GemmK,
GemmABlockCopyThreadSliceLengths_GemmM,
GemmABlockCopyThreadSliceLengths_GemmKPack>;
using GemmABlockCopyThreadClusterArrangeOrder =
Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
// B matrix copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN =
GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmKPack =
GemmKPack / GemmBBlockCopyClusterLengths_GemmKPack;
using GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack =
Sequence<1,
GemmBBlockCopyClusterLengths_GemmK,
GemmBBlockCopyClusterLengths_GemmN,
GemmBBlockCopyClusterLengths_GemmKPack>;
using GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack =
Sequence<1,
GemmBBlockCopyThreadSliceLengths_GemmK,
GemmBBlockCopyThreadSliceLengths_GemmN,
GemmBBlockCopyThreadSliceLengths_GemmKPack>;
using GemmBBlockCopyThreadClusterArrangeOrder =
Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
// gridwise GEMM
constexpr auto wkgrp_schd_order = NBlock1MBlock0;
constexpr auto gridwise_conv =
GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
GridSize,
BlockSize,
FLOAT, // Input data type
FLOAT_ACCUM, // Acc data type
FLOAT, // Output data type
decltype(in_n_c_hi_wi_desc),
decltype(wei_k_cpergroup_y_x_desc),
decltype(out_n_k_ho_wo_desc),
G,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
GemmABlockCopyThreadClusterArrangeOrder,
GemmABlockCopySrcAccessOrder,
GemmABlockCopyDstAccessOrder,
GemmABlockCopySrcDataPerRead_GemmKPack,
GemmABlockCopyDstDataPerWrite_GemmKPack,
GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
GemmBBlockCopyThreadClusterArrangeOrder,
GemmBBlockCopySrcAccessOrder,
GemmBBlockCopyDstAccessOrder,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
wkgrp_schd_order>{};
gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
}