Unverified Commit 3835318c authored by zjing14, committed by GitHub

xdlops_v4r4_fwd fp32/fp16 (#34)



* create files for xdlops

* working on blockwise_gemm_xdlops

* add KReduction

* add m/n repeats

* add 2x2 pipeline

* added 128x128 wavegemm

* use StaticBuffer of vector_type

* break vector type to blk_size

* add kpack into xldops_gemm and blockwise_gemm

* A-broadcast only

* add fp32 mfma instructions

* adding fp16 mfma

* pack half4_t

* rename kperwave to kpack

* add 32x32x8fp16

* add fp16 mfma

* clean code

* clean code

* V4r4 xdlops kpack (#35)

* add kpack with incorrect results

* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2

* add 1x1 kernel

* add gridwise_gemm_v2 - single_buffer

* enabled dwordx4 for fp16
Co-authored-by: Chao Liu <chao.liu2@amd.com>

* refactor fwd-v4r4-xdlops

* add v4r4-nhwc-xdlops

* improve nhwc and nchw perf by tuning parameters, and change scheduling in the gridwise-gemm loop

* tweak scheduling in gridwise gemm

* add v4r3 with a single output copy

* init commit: output with slice win

* adding sliceWin

* add multiple repeats pattern

* start adding bwd-v4r1-xdlops

* use tuple as SrcBuffer

* adding bwd-data v4r1 nhwc xdlops

* fix bug in make_dynamic_naive_tensor_descriptor_aligned_v2()

* fix bug in host bwd-data conv

* initial implementation of bwd-data v4r1 nhwc xdlops

* add launch bound flags

* enable launch bound

* add m/nrepeat=4

* tweak bwd-data v4r1 nhwc xdlops

* added bwd-data v4r1 nhwc xdlops with output A and weight B

* add fwd-v4r4 nhwc xdlops, A input, B weight, C output
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 1685048a
...
@@ -101,6 +101,7 @@ struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    // GM0 and GN0 need to be known at compile-time
     static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
     static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
...
@@ -140,7 +141,7 @@ struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
                           is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
-                      "wrong!");
+                      "wrong! GM0 and GN0 need to be known at compile-time");
         const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
         const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
...
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k0_m_k1_global_desc,
const BGlobalDesc b_k0_n_k1_global_desc,
const CGlobalDesc c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so the compiler can enable the
// corresponding optimizations
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast the __CONSTANT__ void* to a generic void*
// second cast the void* to the concrete Desc*
// the tensor descriptor's copy constructor doesn't accept a pointer in address_space(4)
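// Note: __CONSTANT__ marks the non-modifiable kernel-argument address space (address_space(4))
// mentioned above. Casting through a generic const void* strips that qualifier so the descriptor
// can be copied by value; this assumes the host side has placed a valid descriptor object of the
// matching type at the pointed-to address.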
const auto a_k0_m_k1_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
const auto b_k0_n_k1_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
const auto c_m0_m1_m2_n_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M_KPack,
typename ABlockTransferThreadClusterLengths_K_M_KPack,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_KPack,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N_KPack,
typename BBlockTransferThreadClusterLengths_K_N_KPack,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_KPack,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
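// Note: the factor of 2 above accounts for the even/odd LDS tiles used by the double-buffered
// pipeline in Run(). integer_least_multiple(x, a), as the name suggests, rounds x up to the next
// multiple of a (illustrative numbers only: integer_least_multiple(130, 8) == 136), so each tile
// starts at a KPack-aligned offset.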
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
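// Note: __builtin_amdgcn_readfirstlane broadcasts lane 0's value, which the compiler keeps in a
// scalar register; since block_work_idx is uniform across the block this is safe, and it keeps the
// block-offset arithmetic on the scalar unit instead of per-lane VGPRs.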
// lds max alignment
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, KPack>,
ABlockTransferThreadSliceLengths_K_M_KPack,
ABlockTransferThreadClusterLengths_K_M_KPack,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_global_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_KPack,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_global_desc,
make_multi_index(0, m_block_data_idx_on_global, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, KPack>,
BBlockTransferThreadSliceLengths_K_N_KPack,
BBlockTransferThreadClusterLengths_K_N_KPack,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_global_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_KPack,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_global_desc,
make_multi_index(0, n_block_data_idx_on_global, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong! MPerBlock must be divisible by MPerWave * MRepeat and NPerBlock by NPerWave * NRepeat");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
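// c_thread_buf layout (as implied by the two descriptors above): the outer StaticBuffer is indexed
// by (MRepeat, NRepeat, NumXdlops), and each element is a vector_type holding the NumBlks x BlkSize
// accumulator lanes produced by one xdlops instance; the output stage below walks exactly this
// (mr, nr, xdlops, blk, blk-element) hierarchy when writing C back out.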
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
// auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
// c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
// ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
// decltype(c_m0_m1_n0_n1_thread_desc),
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
}
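// Pipeline sketch (double-buffered): the preload above fills the "even" LDS tiles with K-chunk 0.
// Each main-loop iteration below then (1) moves the source slice window and issues the global reads
// for the next chunk, (2) runs the xdlops GEMM on the chunk already resident in LDS, and (3) writes
// the freshly read chunk into the other (odd/even) LDS tile, so global loads overlap with MFMA work
// and only one __syncthreads() per chunk is needed.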
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
asm volatile("s_nop 0");
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
asm volatile("s_nop 0");
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
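// Tail handling: the main loop consumes 2 * KPerBlock of K0 per iteration, so after it exits either
// two K-chunks (HasDoubleTailKBlockLoop == true) or a single chunk remain; the two branches above
// finish those remaining chunks. The flags are template parameters, so the host side is presumably
// expected to derive them from K0 / KPerBlock (e.g. its parity) when instantiating the kernel.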
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t m_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
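// The make_multi_index(...) below splits the flat m index into (m0, m1, m2) with
// m0 = m / (M2 * M1), m1 = (m % (M2 * M1)) / M2, m2 = m % M2, matching the M0 x M1 x M2
// decomposition of the xdlops C layout. Illustrative numbers only: with M1 = 2 and M2 = 4,
// m = 21 maps to (m0, m1, m2) = (2, 1, 1).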
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_m2_n_global_tensor_iterator_hacks);
});
});
});
});
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k0_m_k1_global_desc,
const BGlobalDesc b_k0_n_k1_global_desc,
const CGlobalDesc c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so the compiler can enable the
// corresponding optimizations
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast the __CONSTANT__ void* to a generic void*
// second cast the void* to the concrete Desc*
// the tensor descriptor's copy constructor doesn't accept a pointer in address_space(4)
const auto a_k0_m_k1_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
const auto b_k0_n_k1_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
const auto c_m0_m1_m2_n_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc);
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M_KPack,
typename ABlockTransferThreadClusterLengths_K_M_KPack,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_KPack,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N_KPack,
typename BBlockTransferThreadClusterLengths_K_N_KPack,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_KPack,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, KPack>,
ABlockTransferThreadSliceLengths_K_M_KPack,
ABlockTransferThreadClusterLengths_K_M_KPack,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_global_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_KPack,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_global_desc,
make_multi_index(0, m_block_data_idx_on_global, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, KPack>,
BBlockTransferThreadSliceLengths_K_N_KPack,
BBlockTransferThreadClusterLengths_K_N_KPack,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_global_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_KPack,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_global_desc,
make_multi_index(0, n_block_data_idx_on_global, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong! MPerBlock must be divisible by MPerWave * MRepeat and NPerBlock by NPerWave * NRepeat");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
// register allocation for output
// auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
// c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
// ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
// decltype(c_m0_m1_n0_n1_thread_desc),
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
}
// main body
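// Single-buffer schedule (contrast with the v1 double-buffered pipeline): each iteration issues the
// global reads for the next K-chunk into registers first, then block_sync_lds() before the GEMM
// consumes the chunk currently in LDS, and block_sync_lds() again before RunWrite overwrites that
// same LDS tile. This halves the LDS footprint relative to v1 at the cost of the extra
// synchronization.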
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t m_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_m2_n_global_tensor_iterator_hacks);
});
});
});
});
}
}
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
p_shared_block);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
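// K1 is taken from the type of the grid descriptor, so it must be a compile-time Number<...>: it is
// used below as the LDS alignment, as the K1 extent of the block descriptors, and as the KPack
// argument of the xdlops GEMM, all of which require compile-time constants.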
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
// TODO: turn on this
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
}
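// A host-side caller is expected to verify these constraints before launching the kernel, roughly
// (hedged sketch; any wrapper around it is not part of this file):
//   if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
//   { /* fall back to another kernel or report an unsupported configuration */ }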
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr auto M0 = Number<CLayout.M1()>{};
constexpr auto M1 = Number<CLayout.N1()>{};
constexpr auto M2 = Number<CLayout.M0()>{};
const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M / (M1 * M2), M1, M2)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
return c_m0_m1_m2_n_grid_desc;
}
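// MakeCM0M1M2NGridDescriptor reshapes the [M, N] output into [M / (M1 * M2), M1, M2, N] using the
// M1/M2 factors reported by the xdlops C layout, so that the per-thread Sequence<M0, 1, M2, 1>
// slices written in Run() address the original [M, N] tensor through this 4-d view.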
__host__ __device__ static constexpr auto
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
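// MakeCBlockClusterAdaptor merges the (M / MPerBlock, N / NPerBlock) block grid into a single block
// id; CalculateBottomIndex() in Run() then recovers (m0, n0) from get_block_1d_id(), which for this
// merge transform effectively amounts to (block_id / N0, block_id % N0). Illustrative numbers only:
// with M0 = 4 and N0 = 3, block_id = 7 maps to (m0, n0) = (2, 1).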
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong! MPerBlock must be divisible by MPerWave * MRepeat and NPerBlock by NPerWave * NRepeat");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
K1.value>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
}
// main body
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_iterator_hack);
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
block_sync_lds();
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks =
CGridIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
});
});
});
});
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
}
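// A minimal host-side launch sketch (illustrative only; the grid descriptors, the iterator
// hack types and all tuning parameters are assumed to be prepared by the caller):
//
//   using gridwise_gemm = GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<...>;
//
//   const auto c_m0_m1_m2_n_grid_desc  = gridwise_gemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
//   const auto c_block_cluster_adaptor = gridwise_gemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
//   const index_t grid_size            = gridwise_gemm::CalculateGridSize(c_m_n_grid_desc);
//
//   hipLaunchKernelGGL((kernel_dynamic_gemm_xdlops_v2r3<gridwise_gemm, ...>),
//                      dim3(grid_size), dim3(BlockSize), 0, 0,
//                      p_a_grid, p_b_grid, p_c_grid,
//                      a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc,
//                      c_m0_m1_m2_n_grid_desc, c_block_cluster_adaptor);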
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks,
bool CAccessOrderMRepeatNRepeat>
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
// TODO: turn on this
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
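        // decompose the output M dimension according to the xdlops C layout:
        // M0 = CLayout.M1() (num_groups_blk), M1 = CLayout.N1() (num_input_blks),
        // M2 = CLayout.M0() (group_size); N is split into NRepeat * NWaves * N1 (num_threads_blk)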
constexpr auto M0 = Number<CLayout.M1()>{};
constexpr auto M1 = Number<CLayout.N1()>{};
constexpr auto M2 = Number<CLayout.M0()>{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr auto N0 = Number<CLayout.N1()>{};
constexpr auto N1 = Number<CLayout.N0()>{};
const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return c_m0_m1_m2_n_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
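        // map the 1-d block id to a (m0, n0) tile coordinate on the C grid; the first variant
        // makes consecutive block ids vary fastest along N, the second (disabled) variant along M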
#if 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
#elif 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
make_tuple(Sequence<1, 0>{}),
make_tuple(Sequence<0>{}));
#endif
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
                          NPerBlock % (NPerWave * NRepeat) == 0,
                      "wrong! MPerBlock and NPerBlock need to be divisible by "
                      "MPerWave * MRepeat and NPerWave * NRepeat, respectively");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
K1.value>{};
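        // CLayout describes how each wave's accumulators are laid out:
        // BlkSize   = accumulator registers per output block,
        // NumBlks   = output blocks per xdlops,
        // NumXdlops = xdlops tiles per wave-level GEMM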
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
        static_assert(NumBlks == 1 && NumXdlops == 1,
                      "wrong! only K-reduction MFMA (NumBlks == 1, NumXdlops == 1) is supported");
constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
}
// main body
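        // the next K-slice is read from global memory into registers before the blockwise GEMM
        // runs, so the global loads can overlap with the MFMA compute; block_sync_lds() separates
        // the GEMM's LDS reads from the RunWrite that overwrites the same LDS buffers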
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_iterator_hack);
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
block_sync_lds();
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#if 0
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr index_t N0 = CLayout.N1();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<M0>{},
Number<1>{},
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
c_thread_buf[Number<blk_off>{}]
.template AsType<FloatAcc>()[Number<j>{}];
});
});
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
n_thread_data_on_grid / (N1 * NWaves),
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
n_thread_data_on_grid % (N1 * NWaves) / N1,
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid % N1)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
}
#else
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
auto c_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(0,
0,
0,
0,
m_thread_data_on_grid / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid)};
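            // each helper below writes one (MRepeat, NRepeat) accumulator tile to global memory
            // and moves the destination slice window by a single +/-1 step along the M- or
            // N-repeat dimension, so the traversals below visit the tiles in a serpentine order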
auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
return c_thread_idx_;
};
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
};
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
};
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
};
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
};
            static_assert((MRepeat == 4 && NRepeat == 4) || (MRepeat == 4 && NRepeat == 2) ||
                              (MRepeat == 2 && NRepeat == 4) || (MRepeat == 2 && NRepeat == 2) ||
                              (MRepeat == 2 && NRepeat == 1) || (MRepeat == 1 && NRepeat == 2) ||
                              (MRepeat == 1 && NRepeat == 1),
                          "wrong! unsupported MRepeat/NRepeat combination");
if constexpr(MRepeat == 4 && NRepeat == 4)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I2));
nrepeat_plus_copy(make_tuple(I2, I3));
mrepeat_plus_copy(make_tuple(I3, I3));
nrepeat_minus_copy(make_tuple(I3, I2));
nrepeat_minus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I2, I2));
mrepeat_plus_copy(make_tuple(I3, I2));
nrepeat_plus_copy(make_tuple(I3, I3));
mrepeat_minus_copy(make_tuple(I2, I3));
mrepeat_minus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
else if constexpr(MRepeat == 4 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
mrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 4)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
nrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 1)
{
init_copy(make_tuple(I0, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
}
else if constexpr(MRepeat == 1 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
nrepeat_plus_copy(make_tuple(I0, I1));
}
else if constexpr(MRepeat == 1 && NRepeat == 1)
{
init_copy(make_tuple(I0, I0));
}
}
#endif
}
}; // struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} // namespace ck
#endif
...@@ -101,9 +101,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
        // static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
        //                       remove_cv_t<remove_reference_t<SrcData>>>::value,
        //               "wrong! SrcBuffer data type is wrong");
        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
...@@ -1407,7 +1407,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
        constexpr auto data_to_origin_disp_idx =
            ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
        // src coordinate
        constexpr auto src_ref_to_data_disp_idx =
            src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
......
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
#include "amd_xdlops.hpp"
namespace ck {
enum struct mfma_instr
{
/// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
/// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
/// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
mfma_f32_32x32x4bf16, // k reduction
mfma_f32_16x16x8bf16, // k reduction
};
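// per-instruction MFMA metadata:
//   group_size      - contiguous accumulator registers in one register group
//   num_regs_blk    - accumulator registers each lane holds for one output block
//   num_threads_blk - lanes covering one output block along N
//   num_input_blks  - input blocks a wave is split into (wave_size / num_threads_blk)
//   num_output_blks - output blocks produced per instruction (1 for k-reduction variants)
//   m, n, k         - the MxNxK tile computed by a single instruction
//   k_base          - input elements packed per A/B register operand along K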
template <mfma_instr instr>
struct mfma_info;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 1;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 1;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 8;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 16;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 4;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
#if 0
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}
};
#endif
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
struct xdlops_info
{
static constexpr auto mfma_type = mfma_info<instr>{};
static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr bool IsABroadcast()
{
        static_assert(NPerXdlops >= MPerXdlops,
                      "wrong! only A-broadcast (NPerXdlops >= MPerXdlops) is supported");
return true;
}
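    // k-reduction variants (e.g. 32x32x2 fp32, 32x32x8 fp16) split the wave into
    // num_input_blks input blocks along K and accumulate them into a single output block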
static constexpr bool IsKReduction()
{
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
}
static constexpr index_t GetKPerXdlops()
{
return IsKReduction() ? mfma_type.num_input_blks : 1;
}
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
};
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm
{
template <class base_type_ = base_type,
index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave>
static constexpr auto GetXdlopsInfo();
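    // each specialization below picks the mfma instruction for a given
    // (base_type, MPerWave, NPerWave) combination; 32x32 and 16x16 wave tiles
    // map to the k-reduction instructions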
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
#endif
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
__host__ __device__ constexpr XdlopsGemm()
{
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
        static_assert(mfma_type.k % mfma_type.k_base == 0, "wrong! k must be a multiple of k_base");
}
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / mfma_type.wave_size;
}
template <class ADesc,
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!");
        static_assert(KPack % mfma_type.k_base == 0, "wrong! KPack must be divisible by k_base");
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset / mfma_type.k_base>{}],
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread);
});
}
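    // return the (m, n) origin of output block blk_i of xdlops xdlops_i for the calling lane,
    // derived from the lane's position within the wave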
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t n_offset = blk_i * mfma_type.n + blk_td;
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
return CIndex{m_offset, n_offset};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr auto GetBlkId(const index_t lane_id)
{
return lane_id / mfma_type.num_threads_blk;
}
static constexpr auto GetBlkTd(const index_t lane_id)
{
return lane_id % mfma_type.num_threads_blk;
}
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
struct CLayout
{
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
};
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
};
} // namespace ck
#endif
...@@ -268,6 +268,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
        }
        else if constexpr(N == 8)
        {
#if 0
            vector_type<half_t, 8> tmp;
            tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
...@@ -280,6 +281,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
                0);
            return tmp.AsType<half8_t>()(Number<0>{});
#else
            float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
            return as_type<half8_t>(tmp);
#endif
        }
    }
    else if constexpr(is_same<T, int32_t>::value)
......
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "float_type.hpp"
namespace ck {
// A, B, C, cbsz, abid, blgp
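// cbsz: broadcast size (the A operand of block abid is broadcast to 2^cbsz blocks),
// abid: which input block of A is broadcast, blgp: B-matrix lane-group swizzle pattern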
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x1f32;
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x2f32;
template <index_t COffset>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x1f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x1f32;
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x4f16;
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x8f16;
template <index_t COffset>
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x16f16;
template <index_t COffset>
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f16;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x4f16;
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
}
};
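// The bf16 wrappers below are compiled out (#if 0) and still use the legacy c_vec*_t
// accumulator structs rather than the StaticBuffer/AsType interface used above.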
#if 0
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;
template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
#endif
} // namespace ck
#endif
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#define CK_AMD_GPU_GFX906 1 #define CK_AMD_GPU_GFX906 1
#elif 1 #elif 1
#define CK_AMD_GPU_GFX908 1 #define CK_AMD_GPU_GFX908 1
#elif 1 #elif 0
#define CK_AMD_GPU_GFX1030 1 #define CK_AMD_GPU_GFX1030 1
#endif #endif
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#endif #endif
// launch bounds // launch bounds
#define CK_USE_LAUNCH_BOUNDS 0 #define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS #ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256 #define CK_MAX_THREAD_PER_BLOCK 256
...@@ -116,7 +116,7 @@ ...@@ -116,7 +116,7 @@
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
// merge transformation use magic number division // merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
// hack: have underlying assumption that needs to be satisfied, otherwise it's a bug // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
......
...@@ -174,8 +174,15 @@ __host__ __device__ constexpr auto container_reduce(const Container& x, ...@@ -174,8 +174,15 @@ __host__ __device__ constexpr auto container_reduce(const Container& x,
{ {
static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
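    // reduce x over the index range [IBegin, IEnd) with stride IStep; an empty range
    // (IEnd == IBegin) simply yields init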
return container_reduce_impl( if constexpr(IEnd > IBegin)
x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{}); {
return container_reduce_impl(
x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{});
}
else
{
return init;
}
} }
#endif #endif
......
...@@ -618,6 +618,252 @@ struct vector_type<T, 64> ...@@ -618,6 +618,252 @@ struct vector_type<T, 64>
} }
}; };
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
using type = d128_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
};
template <typename T>
struct vector_type<T, 256>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
using type = d256_t;
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
};
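// Illustrative usage (not part of the library interface): the union lets the same storage be
// reinterpreted at any power-of-two granularity, e.g. for the 128-wide specialization:
//   vector_type<float, 128> v;
//   float4_t y = {0, 1, 2, 3};
//   v.template AsType<float4_t>()(Number<0>{}) = y;          // write sub-vector 0
//   float4_t x = v.template AsType<float4_t>()[Number<0>{}]; // read it back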
// fp32 // fp32
using float2_t = typename vector_type<float, 2>::type; using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type; using float4_t = typename vector_type<float, 4>::type;
......
...@@ -9,25 +9,25 @@ ...@@ -9,25 +9,25 @@
namespace ck { namespace ck {
namespace math { namespace math {
template <class T, T s> template <typename T, T s>
struct scales struct scales
{ {
__host__ __device__ constexpr T operator()(T a) const { return s * a; } __host__ __device__ constexpr T operator()(T a) const { return s * a; }
}; };
template <class T> template <typename T>
struct plus struct plus
{ {
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
}; };
template <class T> template <typename T>
struct minus struct minus
{ {
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; } __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
}; };
template <class T> template <typename T>
struct multiplies struct multiplies
{ {
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; } __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
...@@ -42,81 +42,109 @@ struct multiplies_v2 ...@@ -42,81 +42,109 @@ struct multiplies_v2
} }
}; };
template <class T> template <typename T>
struct maximize struct maximize
{ {
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
}; };
template <class T> template <typename T>
struct minimize struct minimize
{ {
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; } __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
}; };
template <class T> template <typename T>
struct integer_divide_ceiler struct integer_divide_ceiler
{ {
__host__ __device__ constexpr T operator()(T a, T b) const __host__ __device__ constexpr T operator()(T a, T b) const
{ {
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type"); static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
return (a + b - 1) / b; return (a + b - Number<1>{}) / b;
} }
}; };
template <class X, class Y> template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{ {
return x / y; return x / y;
} }
template <class X, class Y> template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{ {
return (x + y - Number<1>{}) / y; return (x + y - Number<1>{}) / y;
} }
template <class X, class Y> template <typename X, typename Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y) __host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
{ {
return y * integer_divide_ceil(x, y); return y * integer_divide_ceil(x, y);
} }
template <class T> template <typename T>
__host__ __device__ constexpr T max(T x) __host__ __device__ constexpr T max(T x)
{ {
return x; return x;
} }
template <class T, class... Ts> template <typename T>
__host__ __device__ constexpr T max(T x, Ts... xs) __host__ __device__ constexpr T max(T x, T y)
{ {
static_assert(sizeof...(xs) > 0, "not enough argument"); return x > y ? x : y;
}
auto y = max(xs...); template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
{
return X > y ? X : y;
}
static_assert(is_same<decltype(y), T>{}, "not the same type"); template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
{
return x > Y ? x : Y;
}
return x > y ? x : y; template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough arguments");
return max(x, max(ys...));
} }
template <class T> template <typename T>
__host__ __device__ constexpr T min(T x) __host__ __device__ constexpr T min(T x)
{ {
return x; return x;
} }
template <class T, class... Ts> template <typename T>
__host__ __device__ constexpr T min(T x, Ts... xs) __host__ __device__ constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
{ {
static_assert(sizeof...(xs) > 0, "not enough argument"); return X < y ? X : y;
}
auto y = min(xs...); template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
{
return x < Y ? x : Y;
}
static_assert(is_same<decltype(y), T>{}, "not the same type"); template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough arguments");
return x < y ? x : y; return min(x, min(ys...));
} }
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
...@@ -171,13 +199,13 @@ __host__ __device__ constexpr auto lcm(X x, Ys... ys) ...@@ -171,13 +199,13 @@ __host__ __device__ constexpr auto lcm(X x, Ys... ys)
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));
} }
template <class T> template <typename T>
struct equal struct equal
{ {
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
}; };
template <class T> template <typename T>
struct less struct less
{ {
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
......
...@@ -153,6 +153,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -153,6 +153,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
return *this; return *this;
} }
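    // Tuple elements are addressed with compile-time indices, so it reports itself as a
    // static buffer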
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
}; };
template <typename... Xs> template <typename... Xs>
......
...@@ -19,7 +19,22 @@ int main(int argc, char* argv[]) ...@@ -19,7 +19,22 @@ int main(int argc, char* argv[])
{ {
using namespace launcher; using namespace launcher;
#if 0 #if 1
    // 1x1 filter, 1x128 image
constexpr index_t N = 1;
constexpr index_t C = 256;
constexpr index_t HI = 1;
constexpr index_t WI = 128;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 56; constexpr index_t HI = 56;
...@@ -93,7 +108,7 @@ int main(int argc, char* argv[]) ...@@ -93,7 +108,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 1x1 filter, 14x14 image // 1x1 filter, 14x14 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
...@@ -153,7 +168,7 @@ int main(int argc, char* argv[]) ...@@ -153,7 +168,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 1 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -245,7 +260,7 @@ int main(int argc, char* argv[]) ...@@ -245,7 +260,7 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 1 #elif 0
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1 #elif 1
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
......
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <thread>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_tensor.hpp"
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
enum ConvBackwardDataAlgo
{
V4R1XDLNHWC,
V4R1R2XDLNHWC,
};
int main(int argc, char* argv[])
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
#if USE_DYNAMIC_MODE
// dynamic mode
if(argc != 22)
{
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
exit(1);
}
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(atoi(argv[2]));
const bool do_verification = atoi(argv[3]);
const int init_method = atoi(argv[4]);
const bool do_log = atoi(argv[5]);
const int nrepeat = atoi(argv[6]);
const index_t N = atoi(argv[7]);
const index_t K = atoi(argv[8]);
const index_t C = atoi(argv[9]);
const index_t Y = atoi(argv[10]);
const index_t X = atoi(argv[11]);
const index_t Hi = atoi(argv[12]);
const index_t Wi = atoi(argv[13]);
const index_t conv_stride_h = atoi(argv[14]);
const index_t conv_stride_w = atoi(argv[15]);
const index_t conv_dilation_h = atoi(argv[16]);
const index_t conv_dilation_w = atoi(argv[17]);
const index_t in_left_pad_h = atoi(argv[18]);
const index_t in_left_pad_w = atoi(argv[19]);
const index_t in_right_pad_h = atoi(argv[20]);
const index_t in_right_pad_w = atoi(argv[21]);
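    // effective (dilated) filter sizes and the resulting output spatial sizes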
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1;
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#else
// static mode
if(argc < 7)
{
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
exit(1);
}
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(atoi(argv[2]));
const bool do_verification = atoi(argv[3]);
const int init_method = atoi(argv[4]);
const bool do_log = atoi(argv[5]);
const int nrepeat = atoi(argv[6]);
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 71;
constexpr index_t Wi = 71;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
const index_t conv_stride_h = 2;
const index_t conv_stride_w = 2;
const index_t conv_dilation_h = 1;
const index_t conv_dilation_w = 1;
const index_t in_left_pad_h = 1;
const index_t in_left_pad_w = 1;
const index_t in_right_pad_h = 1;
const index_t in_right_pad_w = 1;
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1;
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif
#if 1
constexpr index_t in_vector_size = 1;
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
constexpr index_t in_vector_size = 1;
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#endif
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
switch(layout)
{
case ConvTensorLayout::NCHW:
// NCHW
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(C);
in_lengths_host[2] = static_cast<std::size_t>(Hi);
in_lengths_host[3] = static_cast<std::size_t>(Wi);
wei_lengths_host[0] = static_cast<std::size_t>(K);
wei_lengths_host[1] = static_cast<std::size_t>(C);
wei_lengths_host[2] = static_cast<std::size_t>(Y);
wei_lengths_host[3] = static_cast<std::size_t>(X);
out_lengths_host[0] = static_cast<std::size_t>(N);
out_lengths_host[1] = static_cast<std::size_t>(K);
out_lengths_host[2] = static_cast<std::size_t>(Ho);
out_lengths_host[3] = static_cast<std::size_t>(Wo);
break;
case ConvTensorLayout::NHWC:
// NHWC
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(Hi);
in_lengths_host[2] = static_cast<std::size_t>(Wi);
in_lengths_host[3] = static_cast<std::size_t>(C);
wei_lengths_host[0] = static_cast<std::size_t>(K);
wei_lengths_host[1] = static_cast<std::size_t>(Y);
wei_lengths_host[2] = static_cast<std::size_t>(X);
wei_lengths_host[3] = static_cast<std::size_t>(C);
out_lengths_host[0] = static_cast<std::size_t>(N);
out_lengths_host[1] = static_cast<std::size_t>(Ho);
out_lengths_host[2] = static_cast<std::size_t>(Wo);
out_lengths_host[3] = static_cast<std::size_t>(K);
break;
default: throw std::runtime_error("wrong! not implemented");
}
Tensor<in_data_t> in_host(in_lengths_host);
Tensor<in_data_t> in_device(in_lengths_host);
Tensor<in_data_t> wei(wei_lengths_host);
Tensor<out_data_t> out(out_lengths_host);
std::cout << "layout: " << layout << std::endl;
ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: ");
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: ");
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
std::size_t num_thread = std::thread::hardware_concurrency();
if(do_verification)
{
switch(init_method)
{
case 0:
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 1:
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 2:
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
default:
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
}
}
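    // pack tensor lengths and convolution parameters for the device-side call; in dynamic mode
    // these are runtime values, in static mode compile-time Number<> constants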
auto f_make_for_device_nchw = [&]() {
#if USE_DYNAMIC_MODE
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
#else
const auto in_lengths_dev =
make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
const auto out_lengths_dev =
make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
const auto conv_dilations_dev =
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
const auto in_right_pads_dev =
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
#endif
return make_tuple(in_lengths_dev,
wei_lengths_dev,
out_lengths_dev,
conv_strides_dev,
conv_dilations_dev,
in_left_pads_dev,
in_right_pads_dev);
};
auto f_make_for_device_nhwc = [&]() {
#if USE_DYNAMIC_MODE
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
#else
const auto in_lengths_dev =
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
const auto out_lengths_dev =
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
const auto conv_dilations_dev =
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
const auto in_right_pads_dev =
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
#endif
return make_tuple(in_lengths_dev,
wei_lengths_dev,
out_lengths_dev,
conv_strides_dev,
conv_dilations_dev,
in_left_pads_dev,
in_right_pads_dev);
};
const auto nhwc_desc = f_make_for_device_nhwc();
#if USE_CONV_BWD_V4R1_XDL_NHWC
if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<
in_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in_device,
wei,
out,
nrepeat);
}
#endif
#if USE_CONV_BWD_V4R1R2_XDL_NHWC
if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<
in_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in_device,
wei,
out,
nrepeat);
}
#endif
if(do_verification)
{
host_direct_convolution_backward_data(in_host,
wei,
out,
make_tuple(conv_stride_h, conv_stride_w),
make_tuple(conv_dilation_h, conv_dilation_w),
make_tuple(in_left_pad_h, in_left_pad_w),
make_tuple(in_right_pad_h, in_right_pad_w),
layout);
check_error(in_host, in_device);
if(do_log)
{
LogRangeAsType<float>(std::cout << "out : ", out.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in_host : ", in_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in_device: ", in_device.mData, ",") << std::endl;
}
}
}
...@@ -26,18 +26,32 @@ int main(int argc, char* argv[]) ...@@ -26,18 +26,32 @@ int main(int argc, char* argv[])
} }
const bool do_verification = atoi(argv[1]); const bool do_verification = atoi(argv[1]);
const int init_method = atoi(argv[2]); const bool do_log = atoi(argv[2]);
const bool do_log = atoi(argv[3]); const int init_method = atoi(argv[3]);
const int nrepeat = atoi(argv[4]); const int nrepeat = atoi(argv[4]);
#if 0 #if 0
constexpr index_t N = 8; constexpr index_t N = 256;
constexpr index_t C = 8; constexpr index_t C = 256;
constexpr index_t Hi = 4; constexpr index_t HI = 16;
constexpr index_t Wi = 8; constexpr index_t WI = 16;
constexpr index_t K = 256; constexpr index_t K = 256;
constexpr index_t Y = 3; constexpr index_t Y = 1;
constexpr index_t X = 3; constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
...@@ -162,9 +176,9 @@ int main(int argc, char* argv[]) ...@@ -162,9 +176,9 @@ int main(int argc, char* argv[])
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
constexpr index_t Hi = 71; constexpr index_t HI = 71;
constexpr index_t Wi = 71; constexpr index_t WI = 71;
constexpr index_t K = 128; constexpr index_t K = 256;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -430,7 +444,7 @@ int main(int argc, char* argv[]) ...@@ -430,7 +444,7 @@ int main(int argc, char* argv[])
using InRightPads = Sequence<0, 0>; using InRightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1, 14x14, stride 2 // 1x1, 14x14, stride 2
constexpr index_t N = 128; constexpr index_t N = 256;
constexpr index_t C = 1024; constexpr index_t C = 1024;
constexpr index_t Hi = 14; constexpr index_t Hi = 14;
constexpr index_t Wi = 14; constexpr index_t Wi = 14;
...@@ -445,7 +459,7 @@ int main(int argc, char* argv[]) ...@@ -445,7 +459,7 @@ int main(int argc, char* argv[])
using InRightPads = Sequence<0, 0>; using InRightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1, 14x14 // 1x1, 14x14
constexpr index_t N = 128; constexpr index_t N = 256;
constexpr index_t C = 1024; constexpr index_t C = 1024;
constexpr index_t Hi = 14; constexpr index_t Hi = 14;
constexpr index_t Wi = 14; constexpr index_t Wi = 14;
...@@ -636,6 +650,11 @@ int main(int argc, char* argv[]) ...@@ -636,6 +650,11 @@ int main(int argc, char* argv[])
using in_data_t = typename vector_type<float, in_vector_size>::type; using in_data_t = typename vector_type<float, in_vector_size>::type;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1
using in_data_t = half_t;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0 #elif 0
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using in_data_t = typename vector_type<float, in_vector_size>::type; using in_data_t = typename vector_type<float, in_vector_size>::type;
......
...@@ -16,19 +16,31 @@ ...@@ -16,19 +16,31 @@
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0 #define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 1 #define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo enum ConvForwardAlgo
{ {
V4R4NCHW, V4R4NCHW, // 0
V4R4NHWC, V4R4NHWC, // 1
V4R5NCHW, V4R5NCHW, // 2
V5R1NCHW V5R1NCHW, // 3
V4R4XDLNCHW, // 4
V4R4R2XDLNHWC, // 5
V4R4R3XDLNHWC, // 6
V4R4R4XDLNHWC // 7
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -97,21 +109,21 @@ int main(int argc, char* argv[]) ...@@ -97,21 +109,21 @@ int main(int argc, char* argv[])
const int nrepeat = atoi(argv[6]); const int nrepeat = atoi(argv[6]);
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 192;
constexpr index_t Hi = 17; constexpr index_t Hi = 71;
constexpr index_t Wi = 17; constexpr index_t Wi = 71;
constexpr index_t K = 128; constexpr index_t K = 256;
constexpr index_t Y = 1; constexpr index_t Y = 3;
constexpr index_t X = 7; constexpr index_t X = 3;
const index_t conv_stride_h = 1; const index_t conv_stride_h = 2;
const index_t conv_stride_w = 1; const index_t conv_stride_w = 2;
const index_t conv_dilation_h = 1; const index_t conv_dilation_h = 1;
const index_t conv_dilation_w = 1; const index_t conv_dilation_w = 1;
const index_t in_left_pad_h = 0; const index_t in_left_pad_h = 1;
const index_t in_left_pad_w = 3; const index_t in_left_pad_w = 1;
const index_t in_right_pad_h = 0; const index_t in_right_pad_h = 1;
const index_t in_right_pad_w = 3; const index_t in_right_pad_w = 1;
const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1;
...@@ -120,11 +132,16 @@ int main(int argc, char* argv[]) ...@@ -120,11 +132,16 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif #endif
#if 1 #if 0
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1
constexpr index_t in_vector_size = 1;
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 1 #elif 1
constexpr index_t in_vector_size = 16; constexpr index_t in_vector_size = 16;
using in_data_t = int8_t; using in_data_t = int8_t;
...@@ -384,6 +401,114 @@ int main(int argc, char* argv[]) ...@@ -384,6 +401,114 @@ int main(int argc, char* argv[])
} }
#endif #endif
#if USE_CONV_FWD_V4R4_XDL_NCHW
if(algo == ConvForwardAlgo::V4R4XDLNCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4R2_XDL_NHWC
if(algo == ConvForwardAlgo::V4R4R2XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4R3_XDL_NHWC
if(algo == ConvForwardAlgo::V4R4R3XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4R4_XDL_NHWC
if(algo == ConvForwardAlgo::V4R4R4XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
if(do_verification) if(do_verification)
{ {
host_direct_convolution(in, host_direct_convolution(in,
...@@ -397,6 +522,7 @@ int main(int argc, char* argv[]) ...@@ -397,6 +522,7 @@ int main(int argc, char* argv[])
check_error(out_host, out_device); check_error(out_host, out_device);
#if 0
if(do_log) if(do_log)
{ {
LogRange(std::cout << "in : ", in.mData, ",") << std::endl; LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
...@@ -404,5 +530,6 @@ int main(int argc, char* argv[]) ...@@ -404,5 +530,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
} }
#endif
} }
} }
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
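    // compile-time GEMM tile configuration; exactly one branch of the #if/#elif chain below is
    // enabled at a time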
#if 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
in_n_hi_wi_c_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
I0,
I0,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
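    // launch the GEMM driver several times and report the measured time and throughput each run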
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
out_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
in_m0_m1_m2_n_grid_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
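// FLOP count for this convolution: 2 * N * K * Ho * Wo * C * Y * X
// (one multiply and one add per MAC). ave_time is in milliseconds, so
// dividing the FLOP count by 1e9 and by ave_time yields TFlop/s
// (1e9 FLOP per ms == 1 TFlop/s).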
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
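// Notes on this fp32 configuration:
// - the thread cluster 4 x 64 x 1 = 256 threads matches BlockSize, and
//   slice x cluster = (1,2,4) x (4,64,1) = (4,128,4) =
//   (GemmKPerBlock, GemmM/NPerBlock, GemmK1), so each block copies exactly one
//   A tile and one B tile per K step;
// - GemmK1 = 4 packs four fp32 values (16 bytes) along K1, which lines up with
//   dwordx4-wide vector loads for the A block transfer; the fp16 configurations
//   below use GemmK1 = 8 for the same 16-byte granularity.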
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc,
wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
I0,
I0,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
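// GEMM mapping for this backward-data variant: A = output (GemmK0, GemmM, GemmK1),
// B = weight (GemmK0, GemmN, GemmK1), C = input gradient (GemmM, GemmN).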
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: gemmn
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
true // CAccessOrderMRepeatNRepeat
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
out_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
in_m0_m1_m2_n_grid_iterator_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
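// 2 * N * K * Ho * Wo * C * Y * X FLOPs divided by 1e9 and by the time in ms
// gives TFlop/s, as in the computation earlier in this commit.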
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}