"vscode:/vscode.git/clone" did not exist on "ca3115e7e8e3fbd64fe4ef3c19c84c20fa0c80a9"
Commit 2b8e3ece authored by Jing Zhang

add fp16

parent 9b4fdeee
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_FP16_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
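// Worked example (hypothetical shapes, for illustration only): with
// N = 128, C = 64, Hi = Wi = 28, K = 256, Y = X = 3, unit stride, no dilation
// and no padding, Ho = Wo = 26, so GemmM = 256, GemmN = 128 * 26 * 26 = 86528,
// and with GemmKPACK = 4 the outer reduction length is
// GemmK = (64 * 3 * 3) / 4 = 144.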
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmKPACK,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
class GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
class GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
index_t GemmABlockCopySrcDataPerRead_GemmKPACK,
index_t GemmABlockCopyDstDataPerWrite_GemmKPACK,
class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GemmM = K;
constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
constexpr index_t GemmN = N * Ho * Wo;
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among block");
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
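// The two Embed transforms perform implicit im2col: an output pixel ho and a
// filter tap y address the padded input at
// hip = ConvDilationH * y + ConvStrideH * ho (and likewise for wip), so no
// im2col buffer is ever materialized in memory.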
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto in_gemmk_gemmkpack_gemmn_global_desc = transform_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(UnMerge<Sequence<GemmK, GemmKPACK>>{}, PassThrough<GemmN>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
constexpr auto in_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
in_gemmk_gemmkpack_gemmn_global_desc,
make_tuple(PassThrough<GemmK>{}, PassThrough<GemmN>{}, PassThrough<GemmKPACK>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmm_gemmk_gemmkpack_global_desc,
make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
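// Net effect of the two transforms above: the KCYX weight is viewed as a
// GemmK x GemmM x GemmKPACK tensor, with the unfolded C*Y*X axis split into
// (GemmK, GemmKPACK) and K kept whole as GemmM.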
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1<
GridSize,
BlockSize,
Float,
AccDataType,
Float,
decltype(wei_gemmk_gemmm_gemmkpack_global_desc),
decltype(in_gemmk_gemmn_gemmkpack_global_desc),
decltype(out_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
Sequence<0, 1, 2>,
2,
GemmABlockCopySrcDataPerRead_GemmKPACK,
GemmABlockCopyDstDataPerWrite_GemmKPACK,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
Sequence<0, 1, 2>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPACK,
InMemoryDataOperation::Set>{};
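// Note the argument order below: the weight tensor feeds matrix A (stored
// transposed as GemmK x GemmM x GemmKPACK) and the input feeds matrix B, so
// C = transpose(A) * B yields out[GemmM = K, GemmN = N * Ho * Wo].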
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_K_M_KPACK,
class ABlockCopyThreadClusterLengths_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_K_N_KPACK,
class BBlockCopyThreadClusterLengths_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation OutputMemOp>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto K = b_k_n_kpack_global_desc.GetLengths()[0];
constexpr auto N = b_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto M = a_k_m_kpack_global_desc.GetLengths()[1];
constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
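// Each workgroup owns one MPerBlock x NPerBlock tile of C; the 1-D block id
// is decoded into (m, n) tile coordinates by the cluster descriptor above.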
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * NPerBlock;
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
ABlockCopyDstDataPerWrite_KPACK,
KPACK * GemmDataPerReadM,
KPACK * GemmDataPerReadN);
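// Illustrative value: with the host-side config further below (KPACK = 4,
// GemmDataPerReadM/N = 1, both DstDataPerWrite values = 1), max_align =
// lcm(1, 1, 4, 4) = 4, i.e. each LDS tile is aligned to a multiple of
// 4 elements.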
// LDS
// be careful of LDS alignment
constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_k_m_kpack_global_desc),
decltype(a_k_m_kpack_block_desc),
decltype(a_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M_KPACK,
ABlockCopyThreadClusterLengths_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Generic,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({0, k_block_data_on_global, 0}, {0, 0, 0});
constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_k_n_kpack_global_desc),
decltype(b_k_n_kpack_block_desc),
decltype(b_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N_KPACK,
BBlockCopyThreadClusterLengths_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Generic,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({0, b_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block_double[2 * a_block_space];
__shared__ ABFloat p_b_block_double[2 * b_block_space];
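// Classic LDS double buffering: while the blockwise GEMM consumes tile i from
// one half of each buffer, the copies below stage tile i+1 into the other
// half, overlapping global-memory latency with math.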
// register allocation for output
AccFloat p_c_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_c_thread);
blockwise_gemm.XdlopsMatrixCSetZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;
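// The copy windows only ever advance along GemmK (dim 0) in KPerBlock steps;
// each block's M/N origin stays fixed for the whole kernel.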
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
ABFloat* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
ABFloat* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
ABFloat* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
ABFloat* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointers to match how half/bfloat16 data are consumed by the
// gemm: half packs 4 values per vector while bfloat16 packs 2. Since the
// gemm's matrix A and B 2-D indices are computed in units of the vectorized
// type (e.g. float, half2, half4), we recast from a single half to 4 packed
// halves (or 2 packed bfloat16s, respectively).
auto p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_now);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
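// There are K / KPerBlock tiles in total and the main loop retires them in
// pairs, so an even tile count leaves exactly two for the tail while an odd
// count leaves one (already resident in the first LDS buffer).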
if(has_two_iteration_left) // if has 2 iteration left
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
auto p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_double);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double + b_block_space);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
auto p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_double);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
}
// load data from xdlops accumulator registers
blockwise_gemm.XdlopsMatrixCRead(p_c_thread);
// copy output: register to global memory
{
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0();
constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});
using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
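// The xdlops accumulators map to NumBlks row-blocks of BlkSize elements per
// thread; the loop below writes one block at a time, re-deriving its global
// origin from the blockwise GEMM's C-matrix layout.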
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
OutputMemOp>({0, 0, 0, 0},
{k_thread_data_on_global / (K2 * K1),
k_thread_data_on_global % (K2 * K1) / K2,
k_thread_data_on_global % K2,
b_thread_data_on_global})
.Run(p_c_thread + i * BlkSize, p_c_global);
}
}
}
};
} // namespace ck
#endif
@@ -810,7 +810,7 @@ struct XdlopsGemm_t
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
@@ -18,8 +18,6 @@ typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// float16
-typedef float half4_t __attribute__((ext_vector_type(2)));
-typedef float half8_t __attribute__((ext_vector_type(4)));
// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
@@ -28,6 +26,7 @@ typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// float16
using half2_t = half2;
using half4_t = float2;
template <class T, index_t N>
struct vector_type
@@ -164,6 +163,20 @@ struct inner_product_with_conversion
return acc;
}
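// Added scalar path for half4_t: it unpacks the four half lanes and
// accumulates in the destination type T, which is what the XdlopsGemm_t hunk
// above invokes via inner_product_with_conversion<FloatC>.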
__device__ T operator()(half4_t a, half4_t b) const
{
const half* p_a_half = reinterpret_cast<const half*>(&a);
const half* p_b_half = reinterpret_cast<const half*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
};
} // namespace ck
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
// BlockSize = 128, block tile GemmM x GemmN x GemmK = 128 x 64 x 4, GemmKPACK = 4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmKPACK = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t ThreadGemmDataPerReadM = 1;
constexpr index_t ThreadGemmDataPerReadN = 1;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK = Sequence<1, 4, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK = Sequence<4, 32, 1>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPACK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPACK = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK = Sequence<1, 2, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK = Sequence<4, 32, 1>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK = 1;
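// Consistency of the copy decomposition chosen above (illustrative check):
// the A copy tiles KPerBlock x GemmMPerBlock x GemmKPACK = 4 x 128 x 4 with a
// 4 x 32 x 1 thread cluster (= BlockSize 128), each thread moving a 1 x 4 x 4
// slice; the B copy tiles 4 x 64 x 4 with 1 x 2 x 4 slices on the same cluster.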
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
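// Hypothetical example: for K = 256 and N * Ho * Wo = 86528 this gives
// GridSize = ceil(256 / 128) * ceil(86528 / 64) = 2 * 1352 = 2704 workgroups.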
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
GridSize,
BlockSize,
half,
float,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmKPACK,
GemmMPerWave,
GemmNPerWave,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
GemmABlockCopySrcDataPerRead_GemmKPACK,
GemmABlockCopyDstDataPerWrite_GemmKPACK,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPACK>{};
for(index_t i = 0; i < 10; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
}
// warm up
printf("Warn up running %d times...\n", nrepeat);
for(index_t i = 0; i < nrepeat; ++i)
{
launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
printf("Start running %d times...\n", nrepeat);
cudaDeviceSynchronize();
auto start = std::chrono::steady_clock::now();
for(index_t i = 0; i < nrepeat; ++i)
{
launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
cudaDeviceSynchronize();
auto end = std::chrono::steady_clock::now();
float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;
printf("Average elapsed time : %f ms, %f TFlop/s\n",
ave_time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time);
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
@@ -21,6 +21,7 @@
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
{
@@ -524,8 +525,8 @@ int main(int argc, char* argv[])
print_sequence("ConvStrides", ConvStrides{});
print_sequence("ConvDilations", ConvDilations{});
-using in_data_t = float;
-using out_data_t = float;
+using in_data_t = half;
+using out_data_t = half;
Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
@@ -616,7 +617,7 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 1
-device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
+device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,