Commit 7b002f23 authored by Jing Zhang

add v4r4 xdlops

parent 87a75734
@@ -51,702 +51,6 @@ struct make_block_work_sequence<MBlockWork, NBlockWork, NBlock1MBlock0>
__device__ constexpr auto get() { return Sequence<NBlockWork, MBlockWork>{}; }
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_K_M_KPACK,
class ABlockCopyThreadClusterLengths_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_K_N_KPACK,
class BBlockCopyThreadClusterLengths_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation OutputMemOp,
WorkgroupScheduleOrder WorkgroupSchdOrder,
index_t ABlockCopySrcDataStride = 1,
index_t BBlockCopySrcDataStride = 1>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto K = b_k_n_kpack_global_desc.GetLengths()[0];
constexpr auto N = b_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto M = a_k_m_kpack_global_desc.GetLengths()[1];
constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_sequence =
make_block_work_sequence<MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[0] * MPerBlock)
: (block_work_id[1] * MPerBlock);
const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * NPerBlock)
: (block_work_id[0] * NPerBlock);
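// Illustrative example (hypothetical sizes, not from this commit): with M = N = 256 and
// MPerBlock = NPerBlock = 128, MBlockWork = NBlockWork = 2, so the 4 workgroups each own a
// 128x128 tile of C. k_block_data_on_global / b_block_data_on_global are that tile's starting
// row / column; WorkgroupSchdOrder only selects which grid dimension varies fastest.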
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
ABlockCopyDstDataPerWrite_KPACK,
KPACK * GemmDataPerReadM,
KPACK * GemmDataPerReadN);
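// Worked example with hypothetical tuning values: ABlockCopyDstDataPerWrite_KPACK =
// BBlockCopyDstDataPerWrite_KPACK = 4, KPACK = 4 and GemmDataPerReadM = GemmDataPerReadN = 1
// give max_align = lcm(4, 4, 4, 4) = 4 elements.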
// LDS
// be careful of LDS alignment
constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_k_m_kpack_global_desc),
decltype(a_k_m_kpack_block_desc),
decltype(a_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M_KPACK,
ABlockCopyThreadClusterLengths_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
ABlockCopySrcDataStride>({0, k_block_data_on_global, 0}, {0, 0, 0});
constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_k_n_kpack_global_desc),
decltype(b_k_n_kpack_block_desc),
decltype(b_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N_KPACK,
BBlockCopyThreadClusterLengths_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
2, // Dst dim to be written in vector form (KPACK dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
BBlockCopySrcDataStride>({0, b_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// registers
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block_double[2 * a_block_space];
__shared__ ABFloat p_b_block_double[2 * b_block_space];
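// LDS footprint sketch under hypothetical parameters (KPerBlock = 8, MPerBlock = NPerBlock = 128,
// KPACK = 4): each buffer holds 8 * 128 * 4 = 4096 elements (ignoring alignment padding), so each
// double buffer stores 8192 ABFloat values, i.e. 16 KiB of LDS per matrix for fp16.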
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;
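// Each pass advances only the K coordinate of the A/B source windows by KPerBlock;
// the M/N and KPACK coordinates of the windows stay fixed.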
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
ABFloat* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
ABFloat* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
ABFloat* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
ABFloat* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match how the fp16/bfloat16 data types are
// processed in the gemm operation. The fp16 path packs 4 fp16 values while
// the bfloat16 path packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indices are computed with vectorized values in mind (e.g. float, half2, half4),
// we recast the data type from a single fp16/bfloat16 element to 4 packed fp16 /
// 2 packed bfloat16 values, respectively.
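// For example (assuming ABFloat = half and KPACK = 4), the casts below view the LDS buffers
// as arrays of 4-wide fp16 vectors, so one vector element supplies the KPACK values the
// blockwise xdlops GEMM consumes per K step; bfloat16 would use the 2-wide vector type.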
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_now);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_now);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
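// After the main loop, either two K-slices remain (when K is a multiple of 2 * KPerBlock)
// or a single slice remains; in both cases the next slice to compute is already resident
// in LDS from the preceding store/preload.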
if(has_two_iteration_left) // two iterations are left
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double + b_block_space);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
else // one iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
}
// copy output: register to global memory
{
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0();
constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
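// The UnMerge above splits the global M index into (K0, K1, K2) in the same way the
// per-thread origin below decomposes k_thread_data_on_global:
//   k0 = k / (K1 * K2), k1 = (k % (K1 * K2)) / K2, k2 = k % K2.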
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});
using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
// force unrolling the output loop to get rid of scratch memory
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
OutputMemOp>({0, 0, 0, 0},
{k_thread_data_on_global / (K2 * K1),
k_thread_data_on_global % (K2 * K1) / K2,
k_thread_data_on_global % K2,
b_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t GemmDataPerReadM,
index_t GemmDataPerReadN,
class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation OutputMemOp,
WorkgroupScheduleOrder WorkgroupSchdOrder,
index_t ABlockCopySrcDataStride = 1,
index_t BBlockCopySrcDataStride = 1>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto Gi = b_g_k_n_kpack_global_desc.GetLengths()[0];
constexpr auto Go = c_g_m_n_global_desc.GetLengths()[0];
constexpr auto K = b_g_k_n_kpack_global_desc.GetLengths()[1];
constexpr auto N = b_g_k_n_kpack_global_desc.GetLengths()[2];
constexpr auto M = a_g_k_m_kpack_global_desc.GetLengths()[2];
constexpr auto KPACK = b_g_k_n_kpack_global_desc.GetLengths()[3];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr index_t MWaves = MPerBlock / MPerWave;
constexpr index_t NWaves = NPerBlock / NPerWave;
constexpr auto block_work_sequence =
make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t group_id = block_work_id[0];
const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * MPerBlock)
: (block_work_id[2] * MPerBlock);
const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[2] * NPerBlock)
: (block_work_id[1] * NPerBlock);
// LDS mem
constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
ABlockCopyDstDataPerWrite_KPACK,
KPACK * GemmDataPerReadM,
KPACK * GemmDataPerReadN);
// LDS
// be careful of LDS alignment
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
decltype(a_g_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M_KPACK,
ABlockCopyThreadClusterLengths_G_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
3, // Dst dim to be written in vector form (KPACK dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
ABlockCopySrcDataStride>({group_id, 0, m_block_data_on_global, 0}, {0, 0, 0, 0});
constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_g_k_n_kpack_global_desc),
decltype(b_g_k_n_kpack_block_desc),
decltype(b_g_k_n_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N_KPACK,
BBlockCopyThreadClusterLengths_G_K_N_KPACK,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
3, // Dst dim to be written in vector form (KPACK dimension)
BBlockCopySrcDataPerRead, // N dimension
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set,
BBlockCopySrcDataStride>({group_id, 0, n_block_data_on_global, 0}, {0, 0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// registers
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
NPerWave,
MWaves,
NWaves,
GemmDataPerReadM,
GemmDataPerReadN>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block_double[2 * a_block_space];
__shared__ ABFloat p_b_block_double[2 * b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
using blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>;
using blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
ABFloat* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
ABFloat* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
ABFloat* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
ABFloat* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match how the fp16/bfloat16 data types are
// processed in the gemm operation. The fp16 path packs 4 fp16 values while
// the bfloat16 path packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indices are computed with vectorized values in mind (e.g. float, half2, half4),
// we recast the data type from a single fp16/bfloat16 element to 4 packed fp16 /
// 2 packed bfloat16 values, respectively.
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_now);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_now);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // two iterations are left
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double + b_block_space);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
else // one iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_a_block_double);
const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
p_b_block_double);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
}
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
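// Each thread holds NumBlks chunks of BlkSize accumulator values; chunk i is sourced from
// c_thread_vec.n + i * BlkSize and its origin in the C tile comes from
// GetBeginOfThreadMatrixC(i), see the copy loop below.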
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(PassThrough<Go>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
// force unrolling the output loop to get rid of scratch memory
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
OutputMemOp>(
{0, 0, 0, 0, 0},
{group_id,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
@@ -812,13 +116,13 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t g_block_data_on_global = block_work_id[0];
const index_t g_block_data_on_global = block_work_id[Number<0>{}];
const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * MPerBlock)
: (block_work_id[2] * MPerBlock);
? (block_work_id[Number<1>{}] * MPerBlock)
: (block_work_id[Number<2>{}] * MPerBlock);
const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[2] * NPerBlock)
: (block_work_id[1] * NPerBlock);
? (block_work_id[Number<2>{}] * NPerBlock)
: (block_work_id[Number<1>{}] * NPerBlock);
constexpr index_t max_align = KPack;
@@ -826,7 +130,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v5<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
@@ -843,8 +147,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, m_block_data_on_global, 0},
{0, 0, 0, 0});
InMemoryDataOperation::Set>(
make_multi_index(g_block_data_on_global, 0, m_block_data_on_global, 0),
make_multi_index(0, 0, 0, 0));
constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, NPerBlock, KPack>{}, Number<max_align>{});
@@ -867,8 +172,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, n_block_data_on_global, 0},
{0, 0, 0, 0});
InMemoryDataOperation::Set>(
make_multi_index(g_block_data_on_global, 0, n_block_data_on_global, 0),
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
@@ -918,14 +224,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
// ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// load next data from device mem
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
block_sync_lds();
@@ -943,7 +246,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
block_sync_lds();
// store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
}
@@ -1015,298 +318,12 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
{0, 0, 0, 0, 0},
{g_block_data_on_global,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
}
};
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
class AccFloat,
class CFloat,
class AGlobalDesc,
class BGlobalDesc,
class CGlobalDesc,
index_t MPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t BPerWave,
class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
class ABlockCopyThreadClusterArrangeOrder,
class ABlockCopySrcAccessOrder,
class ABlockCopyDstAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_KPACK,
class BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
class BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
class BBlockCopyThreadClusterArrangeOrder,
class BBlockCopySrcAccessOrder,
class BBlockCopyDstAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_KPACK,
InMemoryDataOperation CGlobalMemoryOp,
WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseBatchGemmXdlops_gkmkpack_gkn1bkpack_gmn_v2
{
__device__ void Run(const ABFloat* const __restrict__ p_a_global,
const ABFloat* const __restrict__ p_b_global,
CFloat* const __restrict__ p_c_global) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
constexpr auto b_g_k_n1_b_kpack_global_desc = BGlobalDesc{};
constexpr auto c_g_m_n_global_desc = CGlobalDesc{};
constexpr auto G = c_g_m_n_global_desc.GetLengths()[0];
constexpr auto M = c_g_m_n_global_desc.GetLengths()[1];
constexpr auto N = c_g_m_n_global_desc.GetLengths()[2];
constexpr auto K = b_g_k_n1_b_kpack_global_desc.GetLengths()[1];
constexpr auto in_N1 = b_g_k_n1_b_kpack_global_desc.GetLengths()[2];
constexpr auto B = b_g_k_n1_b_kpack_global_desc.GetLengths()[3];
constexpr auto KPack = b_g_k_n1_b_kpack_global_desc.GetLengths()[4];
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
constexpr index_t BWavePerBlock = in_N1;
static_assert((G * MBlockWork * BBlockWork) == GridSize, "Invalid GridSize");
constexpr auto block_work_sequence =
make_batch_block_work_sequence<G, MBlockWork, BBlockWork, WorkgroupSchdOrder>{}.get();
constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t g_block_data_on_global = block_work_id[0];
const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[1] * MPerBlock)
: (block_work_id[2] * MPerBlock);
const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
? (block_work_id[2] * BPerBlock)
: (block_work_id[1] * BPerBlock);
constexpr index_t max_align = KPack;
// LDS be careful of LDS alignment
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
decltype(a_g_k_m_kpack_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_G_K_M_KPACK,
ABlockCopyThreadClusterLengths_G_K_M_KPACK,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
ABlockCopyDstAccessOrder,
ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
3, // Dst dim to be written in vector form (KPack dimension)
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, m_block_data_on_global, 0},
{0, 0, 0, 0});
constexpr auto b_g_k_n1_b_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, in_N1, BPerBlock, KPack>{}, Number<max_align>{});
// input blockwise copy
auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(b_g_k_n1_b_kpack_global_desc),
decltype(b_g_k_n1_b_kpack_block_desc),
decltype(b_g_k_n1_b_kpack_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
BBlockCopyDstAccessOrder,
BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
4, // Dst dim to be written in vector form (KPack dimension)
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_KPACK,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>({g_block_data_on_global, 0, 0, b_block_data_on_global, 0},
{0, 0, 0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, BPerBlock * in_N1] is in LDS
// c_mtx[MPerBlock, BPerBlock * in_N1] is distributed among threads, and saved in
// registers
constexpr auto a_k_m_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
constexpr auto b_k_n_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<BPerBlock * in_N1>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
ABFloat,
MPerWave,
BPerWave,
MWavePerBlock,
BWavePerBlock,
1,
1>{};
constexpr index_t a_block_space =
math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_g_k_n1_b_kpack_block_desc.GetElementSpace(), max_align);
__shared__ ABFloat p_a_block[a_block_space];
__shared__ ABFloat p_b_block[b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block);
b_blockwise_copy.Run(p_b_global, p_b_block);
}
constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0, 0>{};
// main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// load next data from device mem
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
block_sync_lds();
// GEMM on current data
const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_a_block);
const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_b_block);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
block_sync_lds();
// store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block);
}
// tail
{
block_sync_lds();
// GEMM on last data
const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_a_block);
const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
p_b_block);
c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
}
// copy output: register to global memory
{
///\todo inconsistent layout of xdlops and tensor
// xdlops layout
// M1 = num_groups;
// M0 = group_size;
// N1 = num_blks_per_wave;
// N0 = num_threads_per_blks;
constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
c_g_m_n_global_desc,
make_tuple(
PassThrough<G>{}, UnMerge<Sequence<M / (M1 * M2), M1, M2>>{}, PassThrough<N>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
// src descriptor
constexpr auto c_g_m0_m1_m2_n_thread_desc =
make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
// force unrolling the output loop to get rid of scratch memory
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.template GetBeginOfThreadMatrixC<MPerWave, B>(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
{0, 0, 0, 0, 0},
{g_block_data_on_global,
make_multi_index(0, 0, 0, 0, 0),
make_multi_index(g_block_data_on_global,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global})
n_thread_data_on_global))
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
}
@@ -27,9 +27,6 @@
#include "amd_inline_asm.hpp"
#endif
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"
#endif
#endif
#include "common_header.hpp"
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
template <class T,
class InDesc,
@@ -10,8 +12,8 @@ template <class T,
class ConvDilations,
class InLeftPads,
class InRightPads>
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(InDesc,
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
@@ -25,29 +27,32 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
{
using namespace ck;
// read params: problem description
constexpr index_t G = CK_PARAM_PROBLEM_G;
constexpr index_t N = CK_PARAM_PROBLEM_N;
constexpr index_t K = CK_PARAM_PROBLEM_K;
constexpr index_t C = CK_PARAM_PROBLEM_C;
constexpr index_t Hi = CK_PARAM_PROBLEM_HI;
constexpr index_t Wi = CK_PARAM_PROBLEM_WI;
constexpr index_t Ho = CK_PARAM_PROBLEM_HO;
constexpr index_t Wo = CK_PARAM_PROBLEM_WO;
constexpr index_t Y = CK_PARAM_PROBLEM_Y;
constexpr index_t X = CK_PARAM_PROBLEM_X;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;
// read params: problem description
constexpr index_t G = 1;
constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;
constexpr index_t C = in_nchw_desc.GetLength(I1);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr auto CPerGroup = C / G;
@@ -58,31 +63,27 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr auto out_n_k_ho_wo_desc =
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
using ConvStrides = Sequence<ConvStrideH, ConvStrideW>;
using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;
using InLeftPads = Sequence<InLeftPadH, InLeftPadW>;
using InRightPads = Sequence<InRightPadH, InRightPadW>;
// read params: tuning parameters
constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
constexpr index_t GemmMPerWave = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE;
constexpr index_t GemmNPerWave = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE;
constexpr index_t GemmKPack = CK_PARAM_TUNABLE_GEMM_KPACK;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 1;
// read params: dependent parameters
constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE;
constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE;
constexpr index_t BlockSize = 256;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
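// Worked example with hypothetical problem sizes: K = 256, N = 16, Ho = Wo = 28 give
// GemmM = 256 and GemmN = 16 * 28 * 28 = 12544, hence
// GridSize = ceil(256 / 128) * ceil(12544 / 128) = 2 * 98 = 196 workgroups.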
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmABlockCopyClusterLengths_GemmM =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
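// With the values hard-coded above, each thread copies a slice of 8 / 4 = 2 elements along
// GemmK, and the 4 x 64 x 1 copy cluster uses all 256 threads of the block.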
@@ -107,19 +108,13 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 1;
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 64;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
@@ -144,22 +139,20 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
// gridwise GEMM
constexpr auto wkgrp_schd_order = NBlock1MBlock0;
constexpr auto gridwise_conv =
GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
GridSize,
BlockSize,
FLOAT, // Input data type
FLOAT_ACCUM, // Acc data type
FLOAT, // Output data type
TDevice, // Input data type
TDevice, // Acc data type
TDevice, // Output data type
decltype(in_n_c_hi_wi_desc),
decltype(wei_k_cpergroup_y_x_desc),
decltype(out_n_k_ho_wo_desc),
@@ -188,6 +181,48 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmBBlockCopyDstAccessOrder,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
wkgrp_schd_order>{};
gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
wkgrp_schd_order>;
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<gridwise_conv,
const TDevice* const __restrict__,
const TDevice* const __restrict__,
TDevice* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
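// ave_time is in milliseconds: flops / 1e9 yields GFLOP, and GFLOP per millisecond equals
// TFLOP/s, matching the unit printed below.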
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
@@ -13,6 +13,7 @@
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp"
@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,