Commit 506a823a authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent 80901f59
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t N0,
index_t N1,
index_t N2,
index_t Ho0,
index_t Ho1,
index_t Ho2,
index_t Wo0,
index_t Wo1,
index_t Wo2,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
class InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_W2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto I7 = Number<7>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N0 * Ho0 * Wo0;
static_assert(N == N0 * N1 * N2 && Ho == Ho0 * Ho1 * Ho2 && Wo == Wo0 * Wo1 * Wo2,
"wrong!");
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I2, Number<Wo1>{}, Number<Wo2>{})
.Fold(I1, Number<Ho1>{}, Number<Ho2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
constexpr auto in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc =
in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
Sequence<1, 4, 7, 0, 3, 6, 2, 5, 8>{});
// batch descritpor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc),
Sequence<0, 1, 2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{},
Sequence<10>{},
Sequence<11>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc =
make_ConstantTensorDescriptor_packed(
Sequence<EPerBlock, N1, Ho1, Wo1, BPerBlock, N2, Ho2, Wo2>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc),
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_W2,
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
{0, 0, 0, 0, 0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
#if 0
if(get_block_1d_id() == 0)
{
printf("id (%d %d), in offset: %d %d, wei offset %d %d\n",
get_block_1d_id(),
get_thread_local_1d_id(),
blockwise_in_copy.mThreadSrcOffset,
blockwise_in_copy.mThreadDstOffset,
blockwise_wei_copy.mThreadSrcOffset,
blockwise_wei_copy.mThreadDstOffset);
}
#endif
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB ==
0,
"GemmDataPerReadB alignment requirement is not satisfied");
constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<GemmMRepeat * GemmMPerThreadSubC>{},
Number<N1 * Ho1 * Wo1 * N2 * Ho2 * Wo2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc),
decltype(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space = math::integer_least_multiple(
in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
#if 0
if(get_block_1d_id() == 0)
{
printf("tid (%d %d), %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
p_wei_register_buffer[0],
p_wei_register_buffer[1],
p_wei_register_buffer[2],
p_wei_register_buffer[3]);
}
#endif
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, Ho1, Wo1, 1, 1, 1, N2, Ho2, Wo2>{});
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<6, 3, 9, 0, 1, 2, 7, 4, 10, 8, 5, 11>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
.Fold(I2, Sequence<Ho1, Ho2>{})
.Fold(I1, Sequence<K1, K2>{})
.Fold(I0, Sequence<N1, N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2);
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<5>{},
Sequence<8>{},
Sequence<0, 4, 7>{},
Sequence<2>{},
Sequence<6>{},
Sequence<9>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0);
threadwise_generic_tensor_slice_copy_v1(
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 12, 1>::type{},
Number<1>{});
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"
namespace ck {
template <index_t GemmKPACK>
struct make_vectorized_WeiDesc_Xdlops
{
template <typename WeiDesc>
__device__ constexpr auto get(WeiDesc&)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};
constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
/* kpack comes from c*y*x */
static_assert((C * Y * X) % GemmKPACK == 0,
"C needs to be multiple of vectorized GemmKPACK");
constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
constexpr auto wei_gemmm_gemmk_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3),
make_tuple(PassThrough<K>{}, PassThrough<C * Y * X>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmm_gemmk_global_desc,
make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmm_gemmk_gemmkpack_global_desc,
make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return wei_gemmk_gemmm_gemmkpack_global_desc;
}
};
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmKPACK,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
class GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
class GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
index_t GemmABlockCopySrcDataPerRead_GemmKPACK,
index_t GemmABlockCopyDstDataPerWrite_GemmKPACK,
class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t GemmM = K;
constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
constexpr index_t GemmN = N * Ho * Wo;
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among block");
// sanity-check for vectorized memory load
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto in_gemmk_gemmkpack_gemmn_global_desc = transform_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(UnMerge<Sequence<GemmK, GemmKPACK>>{}, PassThrough<GemmN>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
constexpr auto in_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
in_gemmk_gemmkpack_gemmn_global_desc,
make_tuple(PassThrough<GemmK>{}, PassThrough<GemmN>{}, PassThrough<GemmKPACK>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc =
make_vectorized_WeiDesc_Xdlops<GemmKPACK>{}.get(wei_k_c_y_x_global_desc);
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1<
GridSize,
BlockSize,
Float,
AccDataType,
Float,
decltype(wei_gemmk_gemmm_gemmkpack_global_desc),
decltype(in_gemmk_gemmn_gemmkpack_global_desc),
decltype(out_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
Sequence<0, 1, 2>,
2,
GemmABlockCopySrcDataPerRead_GemmKPACK,
GemmABlockCopyDstDataPerWrite_GemmKPACK,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
Sequence<0, 1, 2>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPACK,
InMemoryDataOperation::Set>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
class GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
class GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GemmM = K;
constexpr index_t GemmK = C * Y * X;
constexpr index_t GemmN = N * Ho * Wo;
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among block");
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlops_v1<
GridSize,
BlockSize,
Float,
AccDataType,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
Sequence<0, 1>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
InMemoryDataOperation::Set>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class Float,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
// static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave,
// GemmDataPerReadA, GemmDataPerReadB>{};
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
static constexpr index_t WaveSize = 64;
__device__ constexpr auto GetOutputLayout() const
{
return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.GetOutputLayout();
}
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
"BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk =
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.GetBeginOfThreadBlk(i);
const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
const index_t row = waveId / GemmNWaves * GemmMPerWave + thread_mtx_on_blk.row;
return MatrixIndex{row, col};
}
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
const index_t reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
return make_ConstantMatrixDescriptor_packed(Number<reg_size>{}, Number<1>{});
}
__device__ void XdlopsMatrixCSetZero() const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.SetZeroXdlopsRegs(Number<thread_mtx_size>{});
}
template <class FloatC>
__device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_GEMM_XDLOPS_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_K_M,
class ABlockCopyThreadClusterLengths_K_M,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
class BBlockCopyThreadSliceLengths_K_N,
class BBlockCopyThreadClusterLengths_K_N,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseGemmTransposedANormalBNormalCXdlops_v1
{
__device__ void Run(const Float* const __restrict__ p_a_global,
const Float* const __restrict__ p_b_global,
Float* const __restrict__ p_c_global) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = b_k_n_global_desc.GetLengths()[0];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0,
"wrong! M/NPerBlock % M/NPerWave != 0");
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_N,
ABlockCopyDstDataPerWrite_M,
GemmDataPerReadM,
GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_align>{});
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim,
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_align>{});
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim,
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[EPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
Float,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_align);
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
blockwise_gemm.XdlopsMatrixCSetZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using b_blockwise_copy_src_step = Sequence<KPerBlock, 0>;
using a_blockwise_copy_src_step = Sequence<KPerBlock, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// load data from xldop_acc_regs
blockwise_gemm.XdlopsMatrixCRead(p_c_thread);
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// src descriptor
constexpr auto c_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<M0, 1, M2, 1>;
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(p_c_thread + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_G_K_M,
class ABlockCopyThreadClusterLengths_G_K_M,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
class BBlockCopyThreadSliceLengths_G_K_N,
class BBlockCopyThreadClusterLengths_G_K_N,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlops_v1
{
__device__ void Run(const Float* const __restrict__ p_a_global,
const Float* const __restrict__ p_b_global,
Float* const __restrict__ p_c_global) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_g_k_m_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto G = b_g_k_n_global_desc.GetLengths()[0];
constexpr auto K = b_g_k_n_global_desc.GetLengths()[1];
constexpr auto N = b_g_k_n_global_desc.GetLengths()[2];
constexpr auto M = a_g_k_m_global_desc.GetLengths()[2];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0,
"wrong! M/NPerBlock % M/NPerWave != 0");
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<G, MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t group_id = block_work_id[0];
const index_t m_block_data_on_global = block_work_id[1] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[2] * NPerBlock;
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_N,
ABlockCopyDstDataPerWrite_M,
GemmDataPerReadM,
GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_g_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock>{}, Number<max_align>{});
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_g_k_m_global_desc),
decltype(a_g_k_m_block_desc),
decltype(a_g_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M,
ABlockCopyThreadClusterLengths_G_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim,
2,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{group_id, 0, m_block_data_on_global}, {0, 0, 0});
constexpr auto b_g_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, NPerBlock>{}, Number<max_align>{});
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_g_k_n_global_desc),
decltype(b_g_k_n_block_desc),
decltype(b_g_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N,
BBlockCopyThreadClusterLengths_G_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim,
2,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{group_id, 0, n_block_data_on_global}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[EPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
a_g_k_m_block_desc.GetLength(I1), a_g_k_m_block_desc.GetLength(I2));
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
b_g_k_n_block_desc.GetLength(I1), b_g_k_n_block_desc.GetLength(I2));
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
Float,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n_block_desc.GetElementSpace(), max_align);
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
blockwise_gemm.XdlopsMatrixCSetZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using b_blockwise_copy_src_step = Sequence<0, KPerBlock, 0>;
using a_blockwise_copy_src_step = Sequence<0, KPerBlock, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_blockwise_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(b_blockwise_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// load data from xldop_acc_regs
blockwise_gemm.XdlopsMatrixCRead(p_c_thread);
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(PassThrough<G>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0, 0},
{group_id,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(p_c_thread + i * BlkSize, p_c_global);
}
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_K_M_KPACK,
class ABlockCopyThreadClusterLengths_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_K_N_KPACK,
class BBlockCopyThreadClusterLengths_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation OutputMemOp>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto K = b_k_n_kpack_global_desc.GetLengths()[0];
constexpr auto N = b_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto M = a_k_m_kpack_global_desc.GetLengths()[1];
constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * NPerBlock;
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
ABlockCopyDstDataPerWrite_KPACK,
KPACK * GemmDataPerReadM,
KPACK * GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_k_m_kpack_global_desc),
decltype(a_k_m_kpack_block_desc),
decltype(a_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M_KPACK,
ABlockCopyThreadClusterLengths_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Generic,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({0, k_block_data_on_global, 0}, {0, 0, 0});
constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_k_n_kpack_global_desc),
decltype(b_k_n_kpack_block_desc),
decltype(b_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N_KPACK,
BBlockCopyThreadClusterLengths_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Generic,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({0, b_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block_double[2 * a_block_space];
__shared__ ABFloat p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
blockwise_gemm.XdlopsMatrixCSetZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
ABFloat* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
ABFloat* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
ABFloat* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
ABFloat* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are
// processed in gemm operation. Half type packs 4 half values while
// bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// we recast datatype from a single half to 4 packed half/2 packed bfloat16
// respectively.
auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
}
// load data from xldop_acc_regs
blockwise_gemm.XdlopsMatrixCRead(p_c_thread);
// copy output: register to global memory
{
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0();
constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});
using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
OutputMemOp>({0, 0, 0, 0},
{k_thread_data_on_global / (K2 * K1),
k_thread_data_on_global % (K2 * K1) / K2,
k_thread_data_on_global % K2,
b_thread_data_on_global})
.Run(p_c_thread + i * BlkSize, p_c_global);
}
}
}
};
}
#endif
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
#define CK_USE_AMD_XDLOPS_EMULATE 1
namespace ck {
enum struct mfma_instr
{
// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
mfma_f32_32x32x4bf16, // k reduction
mfma_f32_16x16x8bf16, // k reduction
};
template <mfma_instr instr>
struct mfma_info;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32),
"unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float32_t*>(reg_c);
gcnasm_mfma_f32_32x32x1f32<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_16x16x4f32(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 1;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_16x16x1f32<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 1;
static constexpr index_t cycles = 8;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const float* a, const float* b, float* reg_c) const
{
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm");
const auto reg_a = *a;
const auto reg_b = *b;
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_4x4x1f32<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float32_t*>(reg_c);
gcnasm_mfma_f32_32x32x4f16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 8;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_32x32x8f16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 16;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_16x16x16f16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_16x16x4f16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 4;
static constexpr index_t cycles = 8;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const half4_t*>(a));
const auto reg_b = *(reinterpret_cast<const half4_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_4x4x4f16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float32_t*>(reg_c);
gcnasm_mfma_f32_32x32x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_32x32x4bf16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_16x16x8bf16(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float16_t*>(reg_c);
gcnasm_mfma_f32_16x16x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
template <index_t MPerWave, index_t NPerWave>
__device__ void
run(Number<MPerWave>, Number<NPerWave>, const ushort* a, const ushort* b, float* reg_c) const
{
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm");
const auto reg_a = *(reinterpret_cast<const ushort2_t*>(a));
const auto reg_b = *(reinterpret_cast<const ushort2_t*>(b));
auto reg_c_ = reinterpret_cast<float4_t*>(reg_c);
gcnasm_mfma_f32_4x4x2bf16<MPerWave, NPerWave>(reg_a, reg_b, reg_c_);
}
};
template <class data_type, index_t MPerWave, index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();
template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<float, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<half_t, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x8bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}
template <class data_type,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct XdlopsGemm_t
{
struct MatrixIndex
{
index_t row;
index_t col;
};
template <index_t M1_, index_t M0_, index_t N1_, index_t N0_>
struct OutputLayout
{
__device__ static constexpr index_t M1() { return M1_; }
__device__ static constexpr index_t M0() { return M0_; }
__device__ static constexpr index_t N1() { return N1_; }
__device__ static constexpr index_t N0() { return N0_; }
__device__ static constexpr index_t GetBlkSize()
{
return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk;
}
__device__ static constexpr index_t GetNumBlks()
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
return MPerWave * NPerWave / (mfma_type.m * mfma_type.n);
}
};
__device__ constexpr XdlopsGemm_t()
{
static_assert(NPerWave == 4 || NPerWave == 8 || NPerWave == 16 || NPerWave == 32 ||
NPerWave == 64,
"Only support GemmNPerWave == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerWave == 4 || MPerWave == 8 || MPerWave == 16 || MPerWave == 32 ||
MPerWave == 64,
"Only support GemmMPerWave == 4, 8, 16, 32 or 64 for xdlops");
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
}
__device__ static constexpr bool IsABroadcast() { return NPerWave >= MPerWave; }
__device__ static constexpr bool IsKReduction()
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
return mfma_type.num_output_blks == 1 && mfma_type.num_input_blks != 1;
}
#if CK_USE_AMD_XDLOPS_EMULATE
// emulate xdlops
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ void XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC* const __restrict__ p_c_thread) const
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
// K reduction
static_if<IsKReduction()>{}([&](auto) {
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = (k + n) * M;
index_t b_off = (k + n) * N;
index_t c_off = 0;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}).Else([&](auto) {
static_if<IsABroadcast()>{}([&](auto) {
// ABroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < MPerWave / mfma_type.m; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + b * mfma_type.m;
index_t b_off = k * N + n * mfma_type.num_threads_blk;
index_t c_off =
n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
}).Else([&](auto) {
// BBroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < NPerWave / mfma_type.n; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + n * mfma_type.m;
index_t b_off = k * N + b * mfma_type.n;
index_t c_off =
n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
});
});
}
#endif
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC* const __restrict__ p_c_thread) const
{
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");
static_assert(is_same<FloatC, float>::value, "FloatC != float");
#if CK_USE_AMD_XDLOPS_EMULATE
XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
#else
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
static_if<!IsKReduction()>{}([&](auto) {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
for(index_t k = 0; k < K; ++k)
{
mfma_type.run(
Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
}
}).Else([&](auto) {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
mfma_type.run(
Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
}
});
#endif
}
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t col_blk = i % mfma_type.num_output_blks;
index_t row_blk = i / mfma_type.num_output_blks;
index_t col = col_blk * mfma_type.n + blk_td;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size;
static_if<!IsABroadcast()>{}([&](auto) {
col_blk = i / mfma_type.num_output_blks;
row_blk = i % mfma_type.num_output_blks;
col = col_blk * mfma_type.n + blk_td;
row = row_blk * mfma_type.m + blk_id * mfma_type.group_size;
});
return MatrixIndex{row, col};
}
__device__ static constexpr auto GetOutputLayout()
{
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
constexpr auto M1 = mfma_type.num_groups_blk;
constexpr auto M0 = mfma_type.group_size;
constexpr auto N1 = mfma_type.num_input_blks;
constexpr auto N0 = mfma_type.num_threads_blk;
return OutputLayout<M1, M0, N1, N0>{};
}
template <index_t Size>
__device__ void SetZeroXdlopsRegs(Number<Size>) const
{
#if !CK_USE_AMD_XDLOPS_EMULATE
// gcnasm_accvgpr_zero<Size>();
#endif
}
template <index_t Size, class FloatC>
__device__ void ReadXdlopsRegs(Number<Size>, FloatC* const __restrict__ p_c_thread) const
{
#if !CK_USE_AMD_XDLOPS_EMULATE
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
// gcnasm_nop<mfma_type.cycles>();
// gcnasm_accvgpr_read<Size>(p_c_thread);
#else
(void)p_c_thread;
#endif
}
};
} // namespace ck
#endif
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
namespace ck {
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);
template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 32; i++)
{
reg_c_[i + 32] = reg_c_[i] = reg_c_[i] + reg_a * reg_b;
}
}
template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
}
}
template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
}
}
__device__ void gcnasm_mfma_f32_32x32x2f32(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
}
}
__device__ void gcnasm_mfma_f32_16x16x4f32(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);
template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);
template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&, const half4_t&, float32_t*);
template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
__device__ void
gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
__device__ void
gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);
template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);
template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float32_t* reg_c)
{
}
__device__ void
gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
__device__ void
gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);
template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
// clang-format on
}
#endif
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_col2im_eb_nchw.hpp"
template <typename T,
typename ColDesc,
typename ImgDesc,
typename FilterSizes,
typename OutputSizes,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads>
void device_col2im_eb_nchw(ColDesc,
const Tensor<T>& col_eb,
ImgDesc,
Tensor<T>& img_nchw,
FilterSizes,
OutputSizes,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr auto col_eb_desc = ColDesc{};
constexpr auto img_nchw_desc = ImgDesc{};
constexpr index_t N = img_nchw_desc.GetLengths()[0];
constexpr index_t C = img_nchw_desc.GetLengths()[1];
constexpr index_t Hi = img_nchw_desc.GetLengths()[2];
constexpr index_t Wi = img_nchw_desc.GetLengths()[3];
constexpr index_t E = col_eb_desc.GetLengths()[0];
constexpr index_t B = col_eb_desc.GetLengths()[1];
std::size_t data_sz = sizeof(T);
DeviceMem col_eb_device_buf(data_sz * col_eb.mDesc.GetElementSpace());
DeviceMem img_nchw_device_buf(data_sz * img_nchw.mDesc.GetElementSpace());
col_eb_device_buf.ToDevice(col_eb.mData.data());
img_nchw_device_buf.ToDevice(img_nchw.mData.data());
#if 1
constexpr index_t BlockSize = 256;
constexpr index_t EPerBlock = 128;
constexpr index_t BPerBlock = 128;
using BlockCopySubLengths_E_B = Sequence<8, 8>;
using BlockCopyClusterLengths_E_B = Sequence<16, 16>;
using BlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using BlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using BlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t BlockCopyDataPerAccess_B = 1;
#endif
constexpr index_t GridSize =
((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_col2im = GridwiseCol2Im_eb_nchw<GridSize,
BlockSize,
T,
ColDesc,
ImgDesc,
FilterSizes,
OutputSizes,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
EPerBlock,
BPerBlock,
BlockCopySubLengths_E_B,
BlockCopyClusterLengths_E_B,
BlockCopyThreadClusterArrangeOrder,
BlockCopySrcAccessOrder,
BlockCopyDstAccessOrder,
BlockCopyDataPerAccess_B>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_col2im),
const T* const __restrict__,
T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
gridwise_col2im,
const_cast<const T* const __restrict__>(
static_cast<T*>(col_eb_device_buf.GetDeviceBuffer())),
const_cast<T* const __restrict__>(
static_cast<T*>(img_nchw_device_buf.GetDeviceBuffer())));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));
}
img_nchw_device_buf.FromDevice(img_nchw.mData.data());
}
......@@ -149,7 +149,7 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
......@@ -181,28 +181,38 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
GemmCThreadCopyDstDataPerWrite_GemmN1>;
for(index_t i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < 5; ++i)
{
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
......
......@@ -55,25 +55,27 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
constexpr index_t BPerBlock = 32;
constexpr index_t EPerBlock = 32;
constexpr index_t KPerBlock = 8;
constexpr index_t KPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using OutBlockCopySubLengths_K_B_N0 = Sequence<1, 1, 4>;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using OutBlockCopySubLengths_K_B_N0 = Sequence<2, 1, 4>;
using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>;
constexpr index_t OutBlockCopySrcDataPerRead_B = 1;
constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4;
using WeiBlockCopySubLengths_K_E_C0 = Sequence<1, 4, 1>;
using WeiBlockCopySubLengths_K_E_C0 = Sequence<2, 4, 1>;
using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
......@@ -82,8 +84,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
constexpr index_t InThreadCopyDstDataPerWrite_B = 1;
#endif
constexpr index_t C0 = GemmMPerThreadSubC;
constexpr index_t N0 = GemmNPerThreadSubC;
constexpr index_t C0 = GemmMPerThread;
constexpr index_t N0 = GemmNPerThread;
constexpr index_t C1 = C / C0;
constexpr index_t N1 = N / N0;
......@@ -96,7 +98,7 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
using gridwise_conv_bwd_data =
GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
......@@ -112,13 +114,13 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
EPerBlock,
BPerBlock,
KPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
OutBlockCopySubLengths_K_B_N0,
......@@ -129,28 +131,38 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
WeiBlockCopyClusterLengths_K_E_C0,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_C0,
InThreadCopyDstDataPerWrite_B>{};
InThreadCopyDstDataPerWrite_B>;
for(index_t i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < 5; ++i)
{
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
......
......@@ -185,7 +185,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
......@@ -217,28 +217,38 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
GemmCThreadCopyDstDataPerWrite_GemmN1>;
for(index_t i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < 5; ++i)
{
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
......
......@@ -124,7 +124,7 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
......@@ -156,28 +156,38 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
GemmCThreadCopyDstDataPerWrite_GemmN1>;
for(index_t i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < 5; ++i)
{
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
......
......@@ -2,18 +2,13 @@
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename GridwiseOp, index_t GemmId, typename... Xs>
__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs)
{
GridwiseOp::template Run<GemmId>(xs...);
}
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -62,7 +57,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
......@@ -92,36 +87,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
......@@ -157,78 +122,82 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < 5; ++i)
{
using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>;
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
constexpr index_t gemm_id = decltype(gemm_id_){};
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
constexpr index_t gemm_k = gemm_sizes.At(2);
constexpr bool is_gemm_not_empty = gemm_k > 0;
// only compile and run if GEMM is no empty
static_if<is_gemm_not_empty>{}([&](auto fwd) {
launch_kernel(
run_gridwise_convolution_backward_data_v4r1<GridwiseConvBwdData,
fwd(gemm_id),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
for(index_t i = 0; i < nrepeat; ++i)
{
using GridwiseConvBwdData =
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>;
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
constexpr index_t gemm_k = gemm_sizes.At(2);
constexpr bool is_gemm_not_empty = gemm_k > 0;
// only compile and run if GEMM is no empty
static_if<is_gemm_not_empty>{}([&](auto fwd) {
launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__,
decltype(gemm_id)>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
fwd(gemm_id));
});
});
});
}
timer.End();
float time = timer.GetElapsedTime();
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in,
WeiDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());
int num_thread = std::thread::hardware_concurrency();
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
out_device_buf.ToDevice(out.mData.data());
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
#if 1
// 3x3, 34x34, 128 thread
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 1;
constexpr index_t WeiBlockCopyDataPerRead = 1;
constexpr index_t BlockSize = 128;
#endif
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
using gridwise_conv = GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw<GridSize,
BlockSize,
T,
InDesc,
WeiDesc,
OutDesc,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>;
float time = launch_and_time_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));
}
out_device_buf.FromDevice(out.mData.data());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp"
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
// for 3x3, 34x34, v1r1, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<2, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 2;
constexpr index_t WeiBlockCopyDataPerAccess_K = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 1
// for 3x3, 34x34, v1r3, Vega 20
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#elif 0
// for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 28x28, v1r1, Pacal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif 0
// for 1x1, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 14x14, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
using namespace ck;
template <typename T, class InDesc, class WeiDesc, class OutDesc, class LeftPads, class RightPads>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
LeftPads,
RightPads,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
};
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 1
// v1r3, 3x3, 32x32, 1x1 pad
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>;
constexpr index_t InBlockCopyDataPerAccess_N = 4;
using WeiBlockCopySubLengths_CK = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif
#if 1 // debug
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr index_t GridSize = 1;
#endif
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
LeftPads,
RightPads,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_CHWN,
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerAccess_N,
WeiBlockCopySubLengths_CK,
WeiBlockCopyClusterLengths_CK,
WeiBlockCopyDataPerAccess_K,
OutThreadCopyDataPerAccess_N>{};
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
};
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
// for 3x3, 34x34, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 16;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<1, 2, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 4;
#elif 1
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 16;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 2;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 1
// for 3x3, 28x28, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#endif
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#else
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_W>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp"
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t N = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// convert in_nchw to in_cnhw
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
make_ParallelTensorFunctor(
[&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); },
N,
C,
Hi,
Wi)(std::thread::hardware_concurrency());
// convert wei_kcyx to wei_cyxk
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
make_ParallelTensorFunctor(
[&](auto k, auto c, auto y, auto x) { wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); },
K,
C,
Y,
X)(std::thread::hardware_concurrency());
// conver out_nkhw to out_knhw
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 thread
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256;
#elif 0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 256;
#endif
constexpr index_t GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
// mem
std::size_t data_sz = sizeof(T);
DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead +
BPerBlock)); // reserve extra space for BGhostRead
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
#else
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_chwn_desc),
decltype(wei_cyxk_desc),
decltype(out_khwn_desc),
BPerBlock,
KPerBlock,
CPerBlock,
BPerThread,
KPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead,
OutThreadCopyDataPerWrite>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// convert out_khwn to out_nkhw
make_ParallelTensorFunctor(
[&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); },
N,
K,
Ho,
Wo)(std::thread::hardware_concurrency());
}
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
};
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
std::thread::hardware_concurrency());
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_C_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_C_N1_B_N2 = Sequence<8, 2, 16, 1>;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_C_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_C_K = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
#endif
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
#else
GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_C_N1_B_N2,
InBlockCopyClusterLengths_C_N1_B_N2,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_C_K,
WeiBlockCopyClusterLengths_C_K,
WeiBlockCopyDataPerAccess_K>{};
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000)));
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment