Commit 56de337f authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 41b920e2 687d2b7e
......@@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
const Block2ETileMap&)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
//{
// return false;
//}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
......
......@@ -17,18 +17,21 @@ enum struct PipelineVersion
v2,
// v3 is only used in the Stream-K implementation.
v4,
weight_only,
};
template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default>
LoopScheduler LoopSched = LoopScheduler::Default,
bool AEnableLds = true,
bool BEnableLds = true>
constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(PipelineVer == PipelineVersion::v1)
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......@@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v4<NumPrefetch>{};
}
else if constexpr(PipelineVer == PipelineVersion::weight_only)
{
return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
......
......@@ -9,12 +9,12 @@
namespace ck {
template <index_t NumPrefetch>
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<1>
struct GridwiseGemmPipeline_v1<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1>
// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipeline_v1<2, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2>
}
};
template <>
struct GridwiseGemmPipeline_v1<1, false, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_buf = a_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_block_buf = a_block_buf_switch;
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;
template <>
struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleGridBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleGridBuffer& scale_grid_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Global Prefetch Stage 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// Scale read once
b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Dequantization fused in blockwise_copy
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
......@@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{
};
......@@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......
......@@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage, true, true>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......
......@@ -18,11 +18,11 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename ADataType,
typename BDataType,
typename CDataType,
typename AGridDesc,
typename BGridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
......@@ -33,31 +33,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_wmma(
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const BGridDesc b_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
a_grid_desc,
b_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
......@@ -67,8 +63,8 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = a_grid_desc;
ignore = b_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
......@@ -78,21 +74,21 @@ __global__ void
}
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename FloatCShuffle,
typename FloatC,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t K1Value,
......@@ -105,6 +101,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool AEnableLds,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
......@@ -113,6 +110,7 @@ template <index_t BlockSize,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BEnableLds,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
......@@ -121,7 +119,7 @@ template <index_t BlockSize,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
struct GridwiseGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
// FIX ME: To be deprecated
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
using GridwiseGemmPipe =
remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopSched,
AEnableLds,
BEnableLds>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
// Describe how data store to (LDS/VGPR) buffer from Global memory
__host__ __device__ static constexpr auto MakeABlockDescriptor()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
constexpr auto a_block_desc = [&]() {
if constexpr(AEnableLds)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
// K0->M->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
return a_block_desc_k0perblock_mperblock_k1;
return a_block_desc;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
__host__ __device__ static constexpr auto MakeBBlockDescriptor()
{
constexpr auto max_lds_align = K1;
constexpr auto b_block_desc = [&]() {
if constexpr(BEnableLds)
{
// K0->N->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
return b_block_desc;
}
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
if constexpr(AEnableLds)
{
constexpr auto K0PerBlock = KPerBlock / K1;
return make_multi_index(K0PerBlock, 0, 0);
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
return b_block_desc_k0perblock_nperblock_k1;
return a_block_copy_step;
}
__host__ __device__ static constexpr auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
__host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
constexpr auto b_block_copy_step = [&]() {
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
return b_block_copy_step;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
// Describe how data read from (LDS/VGPR) buffer
template <typename ABlockDesc_>
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0perblock_mperblock_k1 =
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Err: merge transform cause non-constexpr issue
// return transform_tensor_descriptor(
// ABlockDesc_{},
// make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
// make_pass_through_transform(Number<MRepeat>{}),
// make_pass_through_transform(I1),
// make_pass_through_transform(I1),
// make_pass_through_transform(Number<A_K1>{})),
// make_tuple(Sequence<0, 3>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
// Sequence<4>{}));
// Workaround, Freeze transform
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}));
}
}();
constexpr auto max_lds_align = K1;
return a_wave_desc;
}
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
template <typename BBlockDesc_>
__host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
{
constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}));
}
}();
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return b_wave_desc;
}
return (a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
__host__ __device__ static constexpr auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
......@@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
(NPerBlock % (NRepeat * NPerWmma)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto GetAProblemsizeMK = [&]() {
if constexpr(AEnableLds)
{
return make_tuple(a_grid_desc.GetLength(I1),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
}
else
{
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
}
};
const auto GetBProblemsizeNK = [&]() {
if constexpr(BEnableLds)
{
return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
}
else
{
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
}
};
const auto M = GetAProblemsizeMK()[I0];
const auto N = GetBProblemsizeNK()[I0];
const auto K = GetAProblemsizeMK()[I1];
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
K == GetBProblemsizeNK()[I1]))
{
printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
GetAProblemsizeMK()[I0],
GetAProblemsizeMK()[I1],
GetBProblemsizeNK()[I0],
GetBProblemsizeNK()[I1],
c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
printf("GridwiseOp err: ProblemSize check");
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
printf("GridwiseOp err: ProblemSize division");
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
printf("GridwiseOp err: Pipeline not support this k_loop");
return false;
}
......@@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
{
return false;
}
......@@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto max_lds_align = K1;
static constexpr auto a_block_space_size_aligned =
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
.GetElementSpaceSize();
static constexpr auto c_shuffle_block_space_offset = 0;
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
};
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
......@@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
p_b_grid, b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......@@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy
const auto K = [&](){
if constexpr(AEnableLds){
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
* a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
}
}();
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = MakeBBlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
if constexpr(AEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared),
SharedMemTrait::a_block_space_size_aligned);
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcData, */ ADataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
......@@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
a_grid_desc_k0_m_k1,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
NumGemmKPrefetchStage>(
a_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0perblock_mperblock_k1,
a_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
FloatB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0perblock_nperblock_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(a_block_buf, a_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc,
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(a_block_buf, a_blockwise_copy);
}
};
auto b_block_trait = [&](){
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(b_block_buf, b_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b_block_buf, b_blockwise_copy);
}
};
auto a_block_buf = a_block_trait()[I0];
auto a_blockwise_copy = a_block_trait()[I1];
auto b_block_buf = b_block_trait()[I0];
auto b_blockwise_copy = b_block_trait()[I1];
/*******************************************************************************/
// GEMM
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatA,
FloatB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
BlockwiseGemmWMMA<BlockSize,
ADataType,
BDataType,
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack,
AEnableLds,
BEnableLds>{};
// Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
/*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
/*******************************************************************************/
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0perblock_mperblock_k1,
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0perblock_nperblock_k1,
b_grid_desc,
b_block_desc,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
KBlockMainLoop);
/*******************************************************************************/
// write out to C, implement shuffle
{
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// This API Provide All dimension (size) you need
// C mapping in single block
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
......@@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
SharedMemTrait::c_shuffle_block_space_size);
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
......@@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatCShuffle,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
ck::tensor_operation::element_wise::PassThrough,
......@@ -571,8 +932,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......@@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
......
......@@ -1333,4 +1333,139 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
ElementwiseOperation element_op_;
};
// Specilized for WMMA
// A single Wave32 is composed by double row
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
// SrcA: From specific thread buffer hold by This RowLane on This Row
// SrcB: From specific thread buffer hold by This RowLane on The other Row
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
uint32_t LowEightRowlaneIdx,
uint32_t HighEightRowLaneIdx,
bool IntraRowSwizzlePerm,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
ignore = src_idx;
}
template <typename SrcSliceOriginIdx,
typename DstSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! SliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
"wrong! Buffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
// scalar per access on each dim
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// src_desc error, non constexpr, caused by merge transform
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v_this_row, v_theother_row;
// int type temp value due to intrinsic requirement
int temp = 0;
// apply element-wise operation
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// apply intra-row permute.
if constexpr(IntraRowSwizzlePerm)
{
temp = __builtin_amdgcn_permlane16(
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
v_this_row = type_convert_sp<SrcData>(temp);
}
// apply inter-row permute.
temp = __builtin_amdgcn_permlanex16(temp,
type_convert_sp<int>(v_this_row),
LowEightRowlaneIdx,
HighEightRowLaneIdx,
1,
0);
v_theother_row = type_convert_sp<SrcData>(temp);
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
type_convert_sp<DstData>(v_theother_row);
}
else
{
// apply type convert
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
type_convert_sp<DstData>(v_this_row);
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
}
});
});
}
ElementwiseOperation element_op_{};
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst_idle
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
// 5. Dequantization happened between read and write.
template <typename SliceLengths,
typename ScaleSliceLengths,
typename SrcElementwiseOperation,
typename ScaleElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename ScaleData,
typename DstData,
typename SrcDesc,
typename ScaleDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t ScaleScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t ScaleScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1_dequant
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const ScaleDesc& scale_desc,
const Index& scale_slice_origin,
const ScaleElementwiseOperation& scale_element_op,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
scale_element_op_(scale_element_op),
dst_element_op_(dst_element_op)
{
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc,
const Index& scale_slice_origin_idx)
{
scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
{
static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename ScaleBuffer::type>, remove_cvref_t<ScaleData>>::value,
"wrong! ScaleBuffer and ScaleData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scale_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});
constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;
constexpr auto scale_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_scale_access_lengths =
container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order);
// make forward steps
const auto scale_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(scale_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto scale_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(scale_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_scale_access_lengths)>{}([&](auto ordered_scale_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_scale_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate scale data index
constexpr auto scale_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i]
: ordered_scale_access_lengths[i] - 1 -
ordered_scale_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) *
scale_scalar_per_access;
}();
constexpr auto scale_data_idx_seq =
generate_sequence_v2([&](auto i) { return Number<scale_data_idx[i]>{}; },
Number<scale_data_idx.Size()>{});
const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
scale_desc, scale_coord_);
using scale_vector_type = vector_type_maker_t<ScaleData, ScaleScalarPerVector>;
using scale_vector_t = typename scale_vector_type::type;
// copy data from scale_buf into scale_vector_container
auto scale_vector_container = scale_vector_type{
scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};
// copy data from scale_vector_container into scale_thread_scratch_
scale_thread_scratch_.template SetAsType<scale_vector_t>(
scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) =
ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move scale coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(scale_desc,
scale_coord_,
scale_forward_steps[scale_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(scale_desc,
scale_coord_,
scale_backward_steps[scale_dim_access_order[i]]);
}
}
});
});
// don't need to move scale coordinate back to slice origin
/*
if constexpr(SrcResetCoordinateAfterRun)
{
const auto scale_reset_step =
make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());
move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
}
*/
}
template <index_t ThreadScratchId>
__device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<SrcData>>::value &&
is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert(SrcVectorDim != DstVectorDim, "wrong");
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
// Do fast numeric convert
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using src_converted_vector_t = typename src_converted_vector_type::type;
// Vector-wise type convert
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
auto src_vector_container = src_vector_type{
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<src_vector_t>(
access_idx)};
auto src_converted_vector_container =
src_converted_vector_type{fast_numeric_converter(src_vector_container)};
src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
access_idx,
src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
});
// Element-scale operation, expect packed multiplication
static_ford<SliceLengths>{}([&](auto idx) {
DstData dst_v;
constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};
// printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
src_element_op_(dst_v,
src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
dst_thread_scratch_(idx) = dst_v;
});
#endif
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// if there is transpose, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
DstData dst_v;
// apply DstElementwiseOperation
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
dst_vector_container.template AsType<DstData>()(i) = dst_v;
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetScaleThreadScratchDescriptor()
{
constexpr auto scale_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});
constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;
constexpr auto scale_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(scale_access_lengths), Number<ScaleScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(scale_access_lengths_and_vector_length[i],
scale_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(scale_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto scale_thread_scratch_desc_ =
decltype(GetScaleThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
/*
template <bool kLastDim>
struct ScaleThreadScratchDesc{};
*/
// Registers, contain raw data loaded from global buffer
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
// Registers, contain fast converted data
using SrcThreadConvertedScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
// Registers, contain scale data
using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
ScaleData,
ScaleScalarPerVector,
decltype(scale_thread_scratch_desc_),
true>;
// Registers, contain dequantized data
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(dst_thread_scratch_desc_),
true>;
using FastTypeConverter = tensor_operation::element_wise::
FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
SrcThreadConvertedScratch src_converted_thread_scratch_;
ScaleThreadScratch scale_thread_scratch_;
DstThreadScratch dst_thread_scratch_;
FastTypeConverter fast_numeric_converter;
SrcCoord src_coord_;
ScaleCoord scale_coord_;
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const ScaleElementwiseOperation scale_element_op_;
const DstElementwiseOperation dst_element_op_;
};
} // namespace ck
......@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
......@@ -100,7 +101,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
......@@ -129,6 +130,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -136,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
......@@ -153,7 +155,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
}
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
WaveSize,
......@@ -166,6 +167,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t acc_pack_number = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -173,28 +175,22 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
WaveSize,
......@@ -207,6 +203,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t acc_pack_number = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -214,7 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
......@@ -227,17 +224,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
{
if constexpr(wave_size == 32)
{
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
}
}
};
#endif
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
WaveSize,
......@@ -250,6 +245,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
......@@ -257,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
......@@ -346,7 +342,7 @@ struct WmmaSelector
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
}
......@@ -358,7 +354,8 @@ template <typename src_type_a,
index_t MPerWmma,
index_t NPerWmma,
index_t KPack,
bool TransposeC = false>
bool TransposeC = false,
bool AssemblyBackend = false>
struct WmmaGemm
{
static constexpr auto I0 = Number<0>{};
......@@ -369,14 +366,14 @@ struct WmmaGemm
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
using CIndex3D = MultiIndex<3>;
__host__ __device__ constexpr WmmaGemm()
{
static_assert(NPerWmma == 16 && MPerWmma == 16,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
}
// WMMA output supporting C = A * B
......@@ -421,9 +418,49 @@ struct WmmaGemm
Sequence<5>{}));
}
// Transposed WMMA Output C' = B' * A'
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
}
__device__ static constexpr index_t GetRegSizePerWmma()
{
return wmma_instr.num_acc_vgprs_per_wave;
return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
......@@ -449,14 +486,16 @@ struct WmmaGemm
,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
}
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
}
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
......@@ -477,12 +516,12 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return GetSwizzledLaneIdLow();
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
return GetLaneIdUnderSubGroup();
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
}
__device__ static CIndex GetBeginOfThreadBlk()
......@@ -493,6 +532,14 @@ struct WmmaGemm
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
}
__device__ static CIndex3D GetBeginOfThreadBlk3D()
{
index_t n_offset = GetLaneIdUnderSubGroup();
index_t m_offset = GetSubGroupId();
return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
}
static constexpr auto wmma =
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma;
......@@ -500,7 +547,10 @@ struct WmmaGemm
__host__ __device__ static constexpr auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
return make_tuple(I1,
I1,
Number<wmma_instr.num_acc_vgprs_per_wave>{},
Number<wmma_instr.acc_pack_number>{});
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace ck {
namespace tensor_operation {
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
device::TensorSpecialization TensorSpec>
__host__ __device__ static auto
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{
// if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
// gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
// {
// throw std::runtime_error("wrong! dimension must match input lengths");
// }
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto gs_ms_ns_lengths =
to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto gs_ms_ns_strides =
to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
if constexpr(TensorSpec == device::TensorSpecialization::Packed)
{
auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(G, M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
else
{
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
// Note: This does not require padding as it only provides G offset calculation. Technically
// descriptor for only G is needed. Here we opt for backward compatibility purpose to return
// G_M_N
const auto grid_desc_g_mraw_nraw =
transform_tensor_descriptor(grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto c_ms_ns_lengths = to_tuple(
gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
}
template <typename NumDims_G_M_N_K_O, // Sequence<>
typename PerBlock_M_N_K_O, // Sequence<>
device::GemmSpecialization GemmSpec,
device::TensorSpecialization ASpec,
device::TensorSpecialization B0Spec,
device::TensorSpecialization B1Spec,
device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
static constexpr auto matrix_padder =
device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, OPerBlock};
//
// A
//
__host__ __device__ static auto MakeAGridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
a_gs_ms_ks_strides_vec);
}
// TODO: rename to G_MRaw_KRaw
__host__ __device__ static auto MakeAGridDescriptor_G_M_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
}
__host__ __device__ static auto MakeAGridDescriptor_M_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return matrix_padder.PadADescriptor_M_K(
MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
}
template <typename AGridDesc_M_K, typename Number>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename AGridDesc_M_K,
typename WmmaK,
typename MRepeat,
typename MWaves,
typename MPerWmma,
typename AK1>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
const AGridDesc_M_K& a_grid_desc_m_k,
const WmmaK&,
const MRepeat&,
const MWaves&,
const MPerWmma&,
const AK1&)
{
const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AKWmma = K / WmmaK{};
constexpr auto AKRow = 2;
constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(
make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
//
// B (alias of B0)
//
__host__ __device__ static auto MakeB0GridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
b0_gs_ns_ks_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
__host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
}
__host__ __device__ static auto MakeB0GridDescriptor_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
// alias of matrix_padder.PadB0Descriptor_N_K
return matrix_padder.PadBDescriptor_N_K(
MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
}
template <typename BGridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_L_K,
typename WmmaK,
typename LRepeat,
typename LWaves,
typename LPerWmma,
typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&,
const LRepeat&,
const LWaves&,
const LPerWmma&,
const BK1&)
{
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = 2;
constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};
return transform_tensor_descriptor(
b_grid_desc_l_k,
make_tuple(make_unmerge_transform(
make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
//
// B1
//
__host__ __device__ static auto MakeB1GridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
b1_gs_os_ns_strides_vec);
}
// TODO: rename to G_NRaw_KRaw
__host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
}
__host__ __device__ static auto MakeB1GridDescriptor_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
// alias of matrix_padder.PadB1Descriptor_O_N
return matrix_padder.PadB1Descriptor_N_K(
MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
}
template <typename B1GridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_N_L,
typename WmmaL,
typename NRepeat,
typename NWaves,
typename NPerWmma,
typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&,
const NRepeat&,
const NWaves&,
const NPerWmma&,
const BL1&)
{
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = 2;
constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};
return transform_tensor_descriptor(
b_grid_desc_n_l,
make_tuple(make_unmerge_transform(
make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
//
// C
//
__host__ __device__ static auto MakeCGridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
c_gs_ms_os_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
__host__ __device__ static auto MakeCGridDescriptor_G_M_N(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
}
__host__ __device__ static auto MakeCGridDescriptor_M_N(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
}
};
} // namespace tensor_operation
} // namespace ck
......@@ -417,7 +417,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using r_t = typename vector_type<T, N>::type;
......
......@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
......@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
......@@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3);
}
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
#if defined(__gfx11__)
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
ignore = a;
ignore = b;
ignore = c;
#endif
}
} // namespace ck
#endif
......@@ -133,6 +133,13 @@ struct scalar_type<int8_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<uint8_t>
{
using type = uint8_t;
static constexpr index_t vector_size = 1;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
......@@ -1037,6 +1044,14 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T>
struct NumericLimits
......
......@@ -99,6 +99,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, float>(float x)
{
union
{
float fp32;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr float type_convert_sp<float, int>(int x)
{
union
{
int int32;
float fp32;
} u = {x};
return u.fp32;
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, half_t>(half_t x)
{
union
{
half_t fp16;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
{
union
{
int int32;
half_t fp16;
} u = {x};
return u.fp16;
}
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
......
......@@ -5,8 +5,11 @@
#include "ck/wrapper/utils/layout_utils.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Layout wrapper that performs the tensor descriptor logic.
......@@ -19,6 +22,8 @@ namespace wrapper {
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
// Disable from doxygen docs generation
/// @cond INTERNAL
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -246,6 +251,7 @@ struct Layout
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
/// @endcond
public:
using LayoutShape = Shape;
......@@ -457,6 +463,8 @@ struct Layout
return unrolled_descriptor_;
}
// Disable from doxygen docs generation
/// @cond INTERNAL
private:
// All dimensions are unrolled
UnrolledDescriptorType unrolled_descriptor_;
......@@ -469,6 +477,7 @@ struct Layout
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const Shape shape_;
/// @endcond
};
} // namespace wrapper
......
......@@ -12,8 +12,11 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Perform optimized copy between two tensors partitions (threadwise copy).
......
......@@ -9,9 +9,14 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
......@@ -45,6 +50,7 @@ __device__ constexpr auto GetBlockDescriptor()
} // namespace detail
} // namespace
/// @endcond
/**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
......
......@@ -7,9 +7,14 @@
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
......@@ -189,6 +194,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
}
} // namespace detail
} // namespace
/// @endcond
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
......@@ -394,6 +400,8 @@ struct Tensor
}
private:
// Disable from doxygen docs generation
/// @cond INTERNAL
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
ElementType,
ElementSpaceSize,
......@@ -428,6 +436,7 @@ struct Tensor
// tensor descriptor (thus all it's transforms) and is linear (1D).
// We store base_offset_ to avoid multiple recalculations.
index_t base_offset_;
/// @endcond
};
} // namespace wrapper
......
......@@ -5,8 +5,11 @@
#include "ck/ck.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
/**
* \brief Traits for blockwise gemm xdl.
......
......@@ -5,8 +5,11 @@
#include "ck/ck.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......
......@@ -17,11 +17,14 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond
// Disable from doxygen docs generation
/// @cond
/// @cond INTERNAL
// forward declaration
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment