Commit 82a15a27 authored by Jing Zhang's avatar Jing Zhang
Browse files

add xdlops emulation on v100

parent e69b1970
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
class GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
class GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GemmM = K;
constexpr index_t GemmK = C * Y * X;
constexpr index_t GemmN = N * Ho * Wo;
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among block");
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlops_v1<
GridSize,
BlockSize,
Float,
AccDataType,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
Sequence<0, 1>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
InMemoryDataOperation::Set>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class Float,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
//static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
static constexpr index_t WaveSize = 64;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.GetOutputLayout(); }
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
"BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.GetBeginOfThreadBlk(i);
const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
const index_t row = waveId / GemmNWaves * GemmMPerWave + thread_mtx_on_blk.row;
return MatrixIndex{row, col};
}
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
const index_t reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
return make_ConstantMatrixDescriptor_packed(Number<reg_size>{}, Number<1>{});
}
__device__ void XdlopsMatrixCSetZero() const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.SetZeroXdlopsRegs(Number<thread_mtx_size>{});
}
template <class FloatC>
__device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
}
};
} // namespace ck
#endif
...@@ -56,8 +56,10 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -56,8 +56,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr auto thread_cluster_desc = constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
#if 0
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths"); "wrong! BlockSize not consistent with ThreadClusterLengths");
#endif
const auto thread_cluster_id = const auto thread_cluster_id =
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id()); thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
...@@ -83,6 +85,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -83,6 +85,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation) if(has_optimized_address_calculation)
{ {
...@@ -93,6 +100,7 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -93,6 +100,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
mThreadwiseLoad.Run(p_block_src, p_thread_buffer); mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
} }
} }
}
template <typename ThreadBufferData, typename BlockDstData> template <typename ThreadBufferData, typename BlockDstData>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer, __device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
...@@ -101,6 +109,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -101,6 +109,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation) if(has_optimized_address_calculation)
{ {
...@@ -111,6 +124,7 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -111,6 +124,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
mThreadwiseStore.Run(p_thread_buffer, p_block_dst); mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
} }
} }
}
template <typename BlockSrcData, typename BlockDstData> template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
......
#ifndef CK_GRIDWISE_GEMM_XDLOPS_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_K_M,
class ABlockCopyThreadClusterLengths_K_M,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
class BBlockCopyThreadSliceLengths_K_N,
class BBlockCopyThreadClusterLengths_K_N,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseGemmTransposedANormalBNormalCXdlops_v1
{
__device__ void Run(const Float* const __restrict__ p_a_global,
const Float* const __restrict__ p_b_global,
Float* const __restrict__ p_c_global) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = b_k_n_global_desc.GetLengths()[0];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0,
"wrong! M/NPerBlock % M/NPerWave != 0");
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_N,
ABlockCopyDstDataPerWrite_M,
GemmDataPerReadM,
GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_align>{});
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim,
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_align>{});
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim,
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[EPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
Float,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_align);
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
blockwise_gemm.XdlopsMatrixCSetZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using b_blockwise_copy_src_step = Sequence<KPerBlock, 0>;
using a_blockwise_copy_src_step = Sequence<KPerBlock, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// load data from xldop_acc_regs
blockwise_gemm.XdlopsMatrixCRead(p_c_thread);
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// src descriptor
constexpr auto c_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<M0, 1, M2, 1>;
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(p_c_thread + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_G_K_M,
class ABlockCopyThreadClusterLengths_G_K_M,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
class BBlockCopyThreadSliceLengths_G_K_N,
class BBlockCopyThreadClusterLengths_G_K_N,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlops_v1
{
__device__ void Run(const Float* const __restrict__ p_a_global,
const Float* const __restrict__ p_b_global,
Float* const __restrict__ p_c_global) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_g_k_m_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto G = b_g_k_n_global_desc.GetLengths()[0];
constexpr auto K = b_g_k_n_global_desc.GetLengths()[1];
constexpr auto N = b_g_k_n_global_desc.GetLengths()[2];
constexpr auto M = a_g_k_m_global_desc.GetLengths()[2];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0,
"wrong! M/NPerBlock % M/NPerWave != 0");
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<G, MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t group_id = block_work_id[0];
const index_t m_block_data_on_global = block_work_id[1] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[2] * NPerBlock;
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_N,
ABlockCopyDstDataPerWrite_M,
GemmDataPerReadM,
GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_g_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock>{}, Number<max_align>{});
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_g_k_m_global_desc),
decltype(a_g_k_m_block_desc),
decltype(a_g_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M,
ABlockCopyThreadClusterLengths_G_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim,
2,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{group_id, 0, m_block_data_on_global}, {0, 0, 0});
constexpr auto b_g_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, NPerBlock>{}, Number<max_align>{});
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_g_k_n_global_desc),
decltype(b_g_k_n_block_desc),
decltype(b_g_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N,
BBlockCopyThreadClusterLengths_G_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim,
2,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{group_id, 0, n_block_data_on_global}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[EPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
a_g_k_m_block_desc.GetLength(I1), a_g_k_m_block_desc.GetLength(I2));
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
b_g_k_n_block_desc.GetLength(I1), b_g_k_n_block_desc.GetLength(I2));
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
Float,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n_block_desc.GetElementSpace(), max_align);
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
blockwise_gemm.XdlopsMatrixCSetZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using b_blockwise_copy_src_step = Sequence<0, KPerBlock, 0>;
using a_blockwise_copy_src_step = Sequence<0, KPerBlock, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// load data from xldop_acc_regs
blockwise_gemm.XdlopsMatrixCRead(p_c_thread);
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(PassThrough<G>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0, 0},
{group_id,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(p_c_thread + i * BlkSize, p_c_global);
}
}
}
};
} // namespace ck
#endif
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
namespace ck {
enum struct mfma_instr
{
// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
mfma_f32_32x32x4bf16, // k reduction
mfma_f32_16x16x8bf16, // k reduction
};
template <mfma_instr instr>
struct mfma_info;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32),
"unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float32_t*>(reg_c);
gcnasm_mfma_f32_32x32x1f32<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_16x16x4f32(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 1;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_16x16x1f32<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 1;
static constexpr index_t cycles = 8;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_4x4x1f32<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
{
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float32_t*>(reg_c);
gcnasm_mfma_f32_32x32x4f16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 8;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
{
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_32x32x8f16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 16;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_16x16x16f16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_16x16x4f16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 4;
static constexpr index_t cycles = 8;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
{
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_4x4x4f16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float32_t*>(reg_c);
gcnasm_mfma_f32_32x32x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_32x32x4bf16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_16x16x8bf16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_16x16x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_4x4x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <class data_type,
index_t MPerWave,
index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();
template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x8bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}
template <class data_type,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct XdlopsGemm_t
{
struct MatrixIndex
{
index_t row;
index_t col;
};
template <index_t M1_, index_t M0_, index_t N1_, index_t N0_>
struct OutputLayout
{
__device__ static constexpr index_t M1() { return M1_; }
__device__ static constexpr index_t M0() { return M0_; }
__device__ static constexpr index_t N1() { return N1_; }
__device__ static constexpr index_t N0() { return N0_; }
__device__ static constexpr index_t GetBlkSize() { return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk; }
__device__ static constexpr index_t GetNumBlks()
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
return MPerWave * NPerWave / (mfma_type.m * mfma_type.n);
}
};
__device__ constexpr XdlopsGemm_t()
{
static_assert(NPerWave == 4 || NPerWave == 8 || NPerWave == 16 || NPerWave == 32 ||
NPerWave == 64,
"Only support GemmNPerWave == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerWave == 4 || MPerWave == 8 || MPerWave == 16 || MPerWave == 32 ||
MPerWave == 64,
"Only support GemmMPerWave == 4, 8, 16, 32 or 64 for xdlops");
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
}
__device__ static constexpr bool IsABroadcast() { return NPerWave >= MPerWave; }
__device__ static constexpr bool IsKReduction()
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
return mfma_type.num_output_blks == 1 && mfma_type.num_input_blks != 1;
}
#if CK_USE_AMD_XDLOPS_EMULATE
// emulate xdlops
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ void XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC* const __restrict__ p_c_thread) const
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
// K reduction
static_if<IsKReduction()>{}([&](auto) {
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = (k + n) * M;
index_t b_off = (k + n) * N;
index_t c_off = 0;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}).Else([&](auto) {
static_if<IsABroadcast()>{}([&](auto) {
// ABroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < MPerWave / mfma_type.m; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + b * mfma_type.m;
index_t b_off = k * N + n * mfma_type.num_threads_blk;
index_t c_off =
n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
}).Else([&](auto) {
// BBroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < NPerWave / mfma_type.n; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + n * mfma_type.m;
index_t b_off = k * N + b * mfma_type.n;
index_t c_off =
n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
});
});
}
#endif
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC* const __restrict__ p_c_thread) const
{
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");
static_assert(is_same<FloatC, float>::value, "FloatC != float");
#if CK_USE_AMD_XDLOPS_EMULATE
XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
#else
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
static_if<!IsKReduction()>{}([&](auto) {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
FloatA a[K];
FloatB b[K];
// load into registers
for(index_t k = 0; k < K; ++k)
{
a[k] = p_a_wave[k * M + laneId];
b[k] = p_b_wave[k * N + laneId];
}
// get pointer of registers
auto pa = reinterpret_cast<const data_type*>(&a);
auto pb = reinterpret_cast<const data_type*>(&b);
for(index_t k = 0; k < K; ++k)
{
constexpr index_t nxdlops = sizeof(FloatA) / (mfma_type.k * sizeof(data_type));
for(index_t i = 0; i < nxdlops; ++i, pa += mfma_type.k, pb += mfma_type.k)
mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, pa, pb, p_c_thread);
}
}).Else([&](auto) {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
FloatA a[K];
FloatB b[K];
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
// load into registers
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
a[k] = p_a_wave[(k + blk_id) * M + blk_td];
b[k] = p_b_wave[(k + blk_id) * N + blk_td];
}
// get pointer of registers
auto pa = reinterpret_cast<const data_type*>(&a);
auto pb = reinterpret_cast<const data_type*>(&b);
constexpr index_t nxdlops =
(sizeof(FloatA) * mfma_type.num_input_blks) / (mfma_type.k * sizeof(data_type));
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
for(index_t i = 0; i < nxdlops; ++i, pa += mfma_type.k, pb += mfma_type.k)
mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, pa, pb, p_c_thread);
}
});
#endif
}
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t col_blk = i % mfma_type.num_output_blks;
index_t row_blk = i / mfma_type.num_output_blks;
index_t col = col_blk * mfma_type.n + blk_td;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size;
static_if<!IsABroadcast()>{}([&](auto) {
col_blk = i / mfma_type.num_output_blks;
row_blk = i % mfma_type.num_output_blks;
col = col_blk * mfma_type.n + blk_td;
row = row_blk * mfma_type.m + blk_id * mfma_type.group_size;
});
return MatrixIndex{row, col};
}
__device__ static constexpr auto GetOutputLayout()
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
constexpr auto M1 = mfma_type.num_groups_blk;
constexpr auto M0 = mfma_type.group_size;
constexpr auto N1 = mfma_type.num_input_blks;
constexpr auto N0 = mfma_type.num_threads_blk;
return OutputLayout<M1, M0, N1, N0>{};
}
template <index_t Size>
__device__ void SetZeroXdlopsRegs(Number<Size>) const
{
#if !CK_USE_AMD_XDLOPS_EMULATE
//gcnasm_accvgpr_zero<Size>();
#endif
}
template <index_t Size, class FloatC>
__device__ void ReadXdlopsRegs(Number<Size>, FloatC* const __restrict__ p_c_thread) const
{
#if !CK_USE_AMD_XDLOPS_EMULATE
//constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
//gcnasm_nop<mfma_type.cycles>();
//gcnasm_accvgpr_read<Size>(p_c_thread);
#else
(void)p_c_thread;
#endif
}
};
} // namespace ck
#endif
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
namespace ck {
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 32; i++)
{
reg_c_[i] += reg_a * reg_b;
reg_c_[i+32] = reg_c[i];
}
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
reg_c_[i+16] = reg_c[i];
}
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
reg_c_[i+16] = reg_c[i];
}
}
__device__ void gcnasm_mfma_f32_32x32x2f32(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
}
}
__device__ void gcnasm_mfma_f32_16x16x4f32(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);
template <>
__device__ void gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&,
const half4_t&,
float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
// clang-format on
}
#endif
...@@ -27,6 +27,9 @@ ...@@ -27,6 +27,9 @@
#if CK_USE_AMD_XDLOPS #if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp" #include "amd_xdlops.hpp"
#else
#include "amd_xdlops_emulate.hpp"
#endif #endif
#endif #endif
...@@ -13,6 +13,19 @@ namespace ck { ...@@ -13,6 +13,19 @@ namespace ck {
using float2_t = float2; using float2_t = float2;
using float4_t = float4; using float4_t = float4;
// float
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// float16
typedef float half4_t __attribute__((ext_vector_type(2)));
typedef float half8_t __attribute__((ext_vector_type(4)));
// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// float16 // float16
using half2_t = half2; using half2_t = half2;
......
...@@ -522,7 +522,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -522,7 +522,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// cdata = 64, BlockSize = 32, 32x64x3 // cdata = 64, BlockSize = 32, 32x64x3
constexpr index_t BlockSize = 32; constexpr index_t BlockSize = 32;
...@@ -559,6 +559,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -559,6 +559,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 32;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 3;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 0
......
...@@ -758,7 +758,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -758,7 +758,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1 #elif 0
// cdata = 64, BlockSize = 32, 32x64x3 // cdata = 64, BlockSize = 32, 32x64x3
constexpr index_t BlockSize = 32; constexpr index_t BlockSize = 32;
...@@ -790,6 +790,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -790,6 +790,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1 #elif 1
// cdata = 64, BlockSize = 64, 64x64x3 // cdata = 64, BlockSize = 64, 64x64x3
......
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t ThreadGemmDataPerReadM = 1;
constexpr index_t ThreadGemmDataPerReadN = 1;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN>{};
for(index_t i = 0; i < 10; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
}
// warm up
printf("Warn up running %d times...\n", nrepeat);
for(index_t i = 0; i < nrepeat; ++i)
{
launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
printf("Start running %d times...\n", nrepeat);
cudaDeviceSynchronize();
auto start = std::chrono::steady_clock::now();
for(index_t i = 0; i < nrepeat; ++i)
{
launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
cudaDeviceSynchronize();
auto end = std::chrono::steady_clock::now();
float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;
printf("Average elapsed time : %f ms, %f TFlop/s\n",
ave_time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time);
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
...@@ -20,495 +20,26 @@ ...@@ -20,495 +20,26 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" //#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0
// 1x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 320;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 224;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 224;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 1
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14 // 1x1, 14x14
constexpr index_t N = 128; constexpr index_t N = 64;
constexpr index_t C = 1024; constexpr index_t C = 1024;
constexpr index_t HI = 14; constexpr index_t HI = 14;
constexpr index_t WI = 14; constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 28x28
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 230;
constexpr index_t WI = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1024; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#endif
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{}); auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
...@@ -603,7 +134,7 @@ int main(int argc, char* argv[]) ...@@ -603,7 +134,7 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment