Unverified Commit 22d438ae authored by Chao Liu, committed by GitHub

Add gridwise GEMM pipeline (#89)

* clean up

* add multiple thread scratches to ThreadwiseTensorSliceTransfer_v3r1

* add 2 stage prefetch

* add more sanity check into transform_tensor_descriptor

* tweak

* enabling 2 stage prefetch to existing gridwise gemm; tweak

* enabling 2 stage prefetch to existing gridwise gemm

* move gridwise gemm pipeline into class; clean up

* add some irregular tile size

* update CalculateHasMainK0BlockLoop for multi-stage-prefetch

* refactor gridwise gemm pipeline class
parent 756a7617
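As an orientation note for the diff below: NumPrefetch controls how many K0PerBlock tiles are fetched from global memory ahead of the tile the block GEMM is currently consuming. The following is a minimal scalar sketch of that schedule with illustrative names only; the real kernel stages each tile through per-thread register scratch and LDS, and the restriction to an even loop count mirrors CheckValidity() further down.

    // Illustrative scalar model of the 2-stage prefetch schedule (not the kernel code
    // in this commit): tile i is consumed while tile i+2 is already being fetched.
    // num_loop is assumed even, matching the CheckValidity() restriction below.
    #include <cassert>
    #include <vector>

    float pipelined_sum(const std::vector<float>& tiles)
    {
        const int num_loop = static_cast<int>(tiles.size());
        assert(num_loop >= 2 && num_loop % 2 == 0);

        float buf[2] = {tiles[0], tiles[1]}; // preload stage 0 and stage 1
        float acc    = 0.f;

        // main body: runs only when num_loop > 2 (cf. CalculateHasMainK0BlockLoop)
        for(int i = 0; i < num_loop - 2; i += 2)
        {
            acc += buf[0];         // "GEMM" on tile i
            buf[0] = tiles[i + 2]; // prefetch tile i+2
            acc += buf[1];         // "GEMM" on tile i+1
            buf[1] = tiles[i + 3]; // prefetch tile i+3
        }

        // tail: the last two prefetched tiles
        acc += buf[0];
        acc += buf[1];
        return acc;
    }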
@@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
{
    // sanity check
    {
+       static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
+                         NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
+                     "wrong! inconsistent number of transforms");
+
        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldVisibleIdss{});
......
@@ -33,7 +33,8 @@ template <index_t BlockSize,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
-         bool ThreadTransferDstResetCoordinateAfterRun>
+         bool ThreadTransferDstResetCoordinateAfterRun,
+         index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
@@ -86,45 +87,39 @@ struct BlockwiseTensorSliceTransfer_v4r1
        }
    }

-   template <typename SrcBuffer, typename SrcStepHacks>
-   __device__ void
-   RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
-   {
-       if(BlockSize == thread_cluster_desc_.GetElementSize() or
-          get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-       {
-           threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
-       }
-   }
-
-   template <typename SrcBuffer>
-   __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
+   template <typename SrcBuffer, index_t ThreadScratchId = 0>
+   __device__ void RunRead(const SrcDesc& src_desc,
+                           const SrcBuffer& src_buf,
+                           Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
-           threadwise_transfer_.RunRead(src_desc, src_buf);
+           threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

-   template <typename DstBuffer>
-   __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
+   template <typename DstBuffer, index_t ThreadScratchId = 0>
+   __device__ void RunWrite(const DstDesc& dst_desc,
+                            DstBuffer& dst_buf,
+                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
-           threadwise_transfer_.RunWrite(dst_desc, dst_buf);
+           threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

-   template <typename SrcBuffer, typename DstBuffer>
+   template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
-                       DstBuffer& dst_buf)
+                       DstBuffer& dst_buf,
+                       Number<ThreadScratchId> thread_scratch_id)
    {
-       RunRead(src_desc, src_buf);
-       RunWrite(dst_desc, dst_buf);
+       RunRead(src_desc, src_buf, thread_scratch_id);
+       RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
@@ -136,21 +131,6 @@ struct BlockwiseTensorSliceTransfer_v4r1
        }
    }

-   // SrcMoveSliceWindowStepHack to control index calculation move slice window
-   template <typename SrcMoveSliceWindowStepHack>
-   __device__ void
-   MoveSrcSliceWindow(const SrcDesc& src_desc,
-                      const Index& step,
-                      const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
-   {
-       if(BlockSize == thread_cluster_desc_.GetElementSize() or
-          get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-       {
-           threadwise_transfer_.MoveSrcSliceWindow(
-               src_desc, step, src_move_slice_window_step_hack);
-       }
-   }
-
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
@@ -182,7 +162,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
                          SrcScalarStrideInVector,
                          DstScalarStrideInVector,
                          ThreadTransferSrcResetCoordinateAfterRun,
-                         ThreadTransferDstResetCoordinateAfterRun>;
+                         ThreadTransferDstResetCoordinateAfterRun,
+                         NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};
......
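With NumThreadScratch > 1, a caller can issue several RunRead() calls back to back, each landing in its own register scratch, and drain them into LDS later with the matching RunWrite(). A hedged usage fragment follows; the descriptor and buffer names are placeholders owned by the caller, and I0/I1 are Number<0>/Number<1>.

    // Hypothetical caller: fill two scratches from global memory, drain them later.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    blockwise_copy.RunRead(grid_desc, grid_buf, I0);      // global -> register scratch 0
    blockwise_copy.MoveSrcSliceWindow(grid_desc, copy_step);
    blockwise_copy.RunRead(grid_desc, grid_buf, I1);      // global -> register scratch 1

    blockwise_copy.RunWrite(block_desc, block_buf, I0);   // LDS <- scratch 0
    // ... consume block_buf ...
    blockwise_copy.RunWrite(block_desc, block_buf, I1);   // LDS <- scratch 1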
#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
#include "common_header.hpp"
namespace ck {
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
index_t NumPrefetch,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1;
// 1-stage prefetch
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1<AGridDesc,
ABlockDesc,
ABlockTransfer,
AGridBuffer,
ABlockBuffer,
ABlockTransferStep,
BGridDesc,
BBlockDesc,
BBlockTransfer,
BGridBuffer,
BBlockBuffer,
BBlockTransferStep,
BlockwiseGemm,
CThreadBuffer,
1,
HasMainLoop>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
#if 0
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#endif
}
};
// 2-stage prefetch
template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1<AGridDesc,
ABlockDesc,
ABlockTransfer,
AGridBuffer,
ABlockBuffer,
ABlockTransferStep,
BGridDesc,
BBlockDesc,
BBlockTransfer,
BGridBuffer,
BBlockBuffer,
BBlockTransferStep,
BlockwiseGemm,
CThreadBuffer,
2,
HasMainLoop>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
{
// Read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// Read i+2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Sync
block_sync_lds();
// Gemm i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// Read i+3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// Sync
block_sync_lds();
// Gemm i+1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
i += 2;
} while(i < (num_loop - 2));
}
// tail
{
// Write num_loop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// Sync
block_sync_lds();
// Gemm num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
// Write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// Sync
block_sync_lds();
// Gemm num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
#endif
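As a hedged reading of the 2-stage specialization above, this is the operation order its Run() produces for num_loop = 4, where I0/I1 are the two thread-scratch ids and "read"/"write" stand for RunRead (global to register scratch) and RunWrite (register scratch to LDS):

    // Trace of GridwiseGemmPipeline_v1<..., 2, true>::Run for num_loop = 4:
    //   preload : read tile0 -> I0, move, read tile1 -> I1
    //   i = 0   : move, write I0 (tile0), read tile2 -> I0, sync, gemm tile0, sync,
    //             move, write I1 (tile1), read tile3 -> I1, sync, gemm tile1, sync
    //   tail    : write I0 (tile2), sync, gemm tile2, sync,
    //             write I1 (tile3), sync, gemm tile3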
@@ -8,6 +8,7 @@
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
+#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

@@ -21,7 +22,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
-         bool HasMainKBlockLoop>
+         bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -40,7 +41,7 @@ __global__ void
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-   GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+   GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
@@ -90,7 +91,8 @@ template <index_t BlockSize,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
-         index_t CThreadTransferDstScalarPerVector>
+         index_t CThreadTransferDstScalarPerVector,
+         index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};

@@ -194,6 +196,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

+       // check NumPrefetch
+       if constexpr(NumPrefetch == 1)
+       {
+           // 1-stage prefetch always supported
+       }
+       else if constexpr(NumPrefetch == 2)
+       {
+           // 2-stage prefetch currently only supports an even number of K0 loops
+           // TODO: add support for odd number of K0 loops
+           if(!((K0 / K0PerBlock) % 2 == 0))
+           {
+               return false;
+           }
+       }
+       else
+       {
+           return false;
+       }
+
        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};
@@ -219,9 +240,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return grid_size;
    }

+   // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
-       const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+       const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }
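A quick numeric check of the updated predicate, using illustrative sizes:

    // With K0PerBlock = 4:
    //   NumPrefetch = 1, K0 = 32: 32 / (1*4) = 8 > 1 -> main loop taken
    //   NumPrefetch = 2, K0 = 32: 32 / (2*4) = 4 > 1 -> main loop taken
    //   NumPrefetch = 2, K0 = 8 :  8 / (2*4) = 1     -> no main loop, tail only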
@@ -316,7 +338,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));

    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

-   template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+   template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,

@@ -381,7 +403,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
-           true>(
+           true,
+           NumPrefetch>(
            a_grid_desc_k0_m_k1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,

@@ -411,7 +434,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
-           true>(
+           true,
+           NumPrefetch>(
            b_grid_desc_k0_n_k1,
            make_multi_index(0, n_block_data_idx_on_grid, 0),
            b_element_op,
@@ -455,51 +479,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-       // preload data into LDS
-       {
-           a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-           b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-           a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-           b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-       }
-
-       // Initialize C
-       c_thread_buf.Clear();
-
-       // main body
-       if constexpr(HasMainKBlockLoop)
-       {
-           index_t k0_block_data_begin = 0;
-
-           do
-           {
-               a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-               b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-
-               a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-
-               block_sync_lds();
-
-               b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-               blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-               block_sync_lds();
-
-               a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-               b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-
-               k0_block_data_begin += K0PerBlock;
-           } while(k0_block_data_begin < (K0 - K0PerBlock));
-       }
-
-       // tail
-       {
-           block_sync_lds();
-           blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-       }
+       // gridwise GEMM pipeline
+       const auto gridwise_gemm_pipeline =
+           GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_blockwise_copy)>,
+                                   remove_cvref_t<decltype(a_grid_buf)>,
+                                   remove_cvref_t<decltype(a_block_buf)>,
+                                   remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_blockwise_copy)>,
+                                   remove_cvref_t<decltype(b_grid_buf)>,
+                                   remove_cvref_t<decltype(b_block_buf)>,
+                                   remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(blockwise_gemm)>,
+                                   remove_cvref_t<decltype(c_thread_buf)>,
+                                   NumPrefetch,
+                                   HasMainK0BlockLoop>{};
+
+       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+       gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+                                  a_block_desc_k0_m_k1,
+                                  a_blockwise_copy,
+                                  a_grid_buf,
+                                  a_block_buf,
+                                  a_block_slice_copy_step,
+                                  b_grid_desc_k0_n_k1,
+                                  b_block_desc_k0_n_k1,
+                                  b_blockwise_copy,
+                                  b_grid_buf,
+                                  b_block_buf,
+                                  b_block_slice_copy_step,
+                                  blockwise_gemm,
+                                  c_thread_buf,
+                                  K0BlockMainLoop);

        // output: register to global memory
        {
......
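On the host side, the result of CalculateHasMainK0BlockLoop() selects one of the two compile-time instantiations of the kernel. A hedged sketch of such a dispatch follows; the kernel name, its parameter list, and the launch arguments are placeholders, not the exact code in this repository.

    // Hypothetical host-side dispatch sketch.
    template <bool HasMainK0BlockLoop>
    __global__ void gemm_kernel_placeholder(const float* a, const float* b, float* c);

    template <typename GridwiseGemm>
    void launch_gemm(const float* a, const float* b, float* c, int K0, int grid_size, int block_size)
    {
        const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

        if(has_main_k0_block_loop)
            hipLaunchKernelGGL(gemm_kernel_placeholder<true>,
                               dim3(grid_size), dim3(block_size), 0, 0, a, b, c);
        else
            hipLaunchKernelGGL(gemm_kernel_placeholder<false>,
                               dim3(grid_size), dim3(block_size), 0, 0, a, b, c);
    }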
@@ -9,6 +9,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
+#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

@@ -22,7 +23,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
-         bool HasMainKBlockLoop>
+         bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -42,7 +43,7 @@ __global__ void
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-   GridwiseGemm::template Run<HasMainKBlockLoop>(
+   GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid,
        p_b_grid,
        p_c_grid,
@@ -95,7 +96,8 @@ template <
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
-   index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
+   index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
+   index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
    static constexpr auto I0 = Number<0>{};

@@ -228,6 +230,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

+       // check NumPrefetch
+       if constexpr(NumPrefetch == 1)
+       {
+           // 1-stage prefetch always supported
+       }
+       else if constexpr(NumPrefetch == 2)
+       {
+           // 2-stage prefetch currently only supports an even number of K0 loops
+           // TODO: add support for odd number of K0 loops
+           if(!((K0 / K0PerBlock) % 2 == 0))
+           {
+               return false;
+           }
+       }
+       else
+       {
+           return false;
+       }
+
        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};
@@ -253,9 +274,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        return grid_size;
    }

+   // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
-       const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+       const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }
@@ -329,7 +351,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-   template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+   template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,

@@ -397,7 +419,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
-           true>(
+           true,
+           NumPrefetch>(
            a_grid_desc_k0_m_k1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,

@@ -427,7 +450,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
-           true>(
+           true,
+           NumPrefetch>(
            b_grid_desc_k0_n_k1,
            make_multi_index(0, n_block_data_idx_on_grid, 0),
            b_element_op,
@@ -471,51 +495,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-       // preload data into LDS
-       {
-           a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-           b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-           a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-           b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-       }
-
-       // Initialize C
-       c_thread_buf.Clear();
-
-       // main body
-       if constexpr(HasMainKBlockLoop)
-       {
-           index_t k0_block_data_begin = 0;
-
-           do
-           {
-               a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-               b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-
-               a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-
-               block_sync_lds();
-
-               b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-               blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-               block_sync_lds();
-
-               a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-               b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-
-               k0_block_data_begin += K0PerBlock;
-           } while(k0_block_data_begin < (K0 - K0PerBlock));
-       }
-
-       // tail
-       {
-           block_sync_lds();
-           blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-       }
+       // gridwise GEMM pipeline
+       const auto gridwise_gemm_pipeline =
+           GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_blockwise_copy)>,
+                                   remove_cvref_t<decltype(a_grid_buf)>,
+                                   remove_cvref_t<decltype(a_block_buf)>,
+                                   remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_blockwise_copy)>,
+                                   remove_cvref_t<decltype(b_grid_buf)>,
+                                   remove_cvref_t<decltype(b_block_buf)>,
+                                   remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(blockwise_gemm)>,
+                                   remove_cvref_t<decltype(c_thread_buf)>,
+                                   NumPrefetch,
+                                   HasMainK0BlockLoop>{};
+
+       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+       gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+                                  a_block_desc_k0_m_k1,
+                                  a_blockwise_copy,
+                                  a_grid_buf,
+                                  a_block_buf,
+                                  a_block_slice_copy_step,
+                                  b_grid_desc_k0_n_k1,
+                                  b_block_desc_k0_n_k1,
+                                  b_blockwise_copy,
+                                  b_grid_buf,
+                                  b_block_buf,
+                                  b_block_slice_copy_step,
+                                  blockwise_gemm,
+                                  c_thread_buf,
+                                  K0BlockMainLoop);

        // shuffle C and write out
        {
......
@@ -9,6 +9,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
+#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

@@ -23,7 +24,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
-         bool HasMainKBlockLoop>
+         bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -46,7 +47,7 @@ __global__ void
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-   GridwiseGemm::template Run<HasMainKBlockLoop>(
+   GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid,
        p_b_grid,
        p_c_grid,
@@ -102,7 +103,8 @@ template <
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
-   index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
+   index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
+   index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
    static constexpr auto I0 = Number<0>{};

@@ -235,6 +237,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

+       // check NumPrefetch
+       if constexpr(NumPrefetch == 1)
+       {
+           // 1-stage prefetch always supported
+       }
+       else if constexpr(NumPrefetch == 2)
+       {
+           // 2-stage prefetch currently only supports an even number of K0 loops
+           // TODO: add support for odd number of K0 loops
+           if(!((K0 / K0PerBlock) % 2 == 0))
+           {
+               return false;
+           }
+       }
+       else
+       {
+           return false;
+       }
+
        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};
@@ -260,9 +281,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
        return grid_size;
    }

+   // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
-       const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+       const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }
@@ -342,7 +364,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-   template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+   template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,

@@ -417,7 +439,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
-           true>(
+           true,
+           NumPrefetch>(
            a_grid_desc_k0_m_k1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,

@@ -447,7 +470,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
-           true>(
+           true,
+           NumPrefetch>(
            b_grid_desc_k0_n_k1,
            make_multi_index(0, n_block_data_idx_on_grid, 0),
            b_element_op,
@@ -491,51 +515,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-       // preload data into LDS
-       {
-           a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-           b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-           a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-           b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-       }
-
-       // Initialize C
-       c_thread_buf.Clear();
-
-       // main body
-       if constexpr(HasMainKBlockLoop)
-       {
-           index_t k0_block_data_begin = 0;
-
-           do
-           {
-               a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-               b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-
-               a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-
-               block_sync_lds();
-
-               b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-               blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-               block_sync_lds();
-
-               a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-               b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-
-               k0_block_data_begin += K0PerBlock;
-           } while(k0_block_data_begin < (K0 - K0PerBlock));
-       }
-
-       // tail
-       {
-           block_sync_lds();
-           blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-       }
+       // gridwise GEMM pipeline
+       const auto gridwise_gemm_pipeline =
+           GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_blockwise_copy)>,
+                                   remove_cvref_t<decltype(a_grid_buf)>,
+                                   remove_cvref_t<decltype(a_block_buf)>,
+                                   remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_blockwise_copy)>,
+                                   remove_cvref_t<decltype(b_grid_buf)>,
+                                   remove_cvref_t<decltype(b_block_buf)>,
+                                   remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(blockwise_gemm)>,
+                                   remove_cvref_t<decltype(c_thread_buf)>,
+                                   NumPrefetch,
+                                   HasMainK0BlockLoop>{};
+
+       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+       gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+                                  a_block_desc_k0_m_k1,
+                                  a_blockwise_copy,
+                                  a_grid_buf,
+                                  a_block_buf,
+                                  a_block_slice_copy_step,
+                                  b_grid_desc_k0_n_k1,
+                                  b_block_desc_k0_n_k1,
+                                  b_blockwise_copy,
+                                  b_grid_buf,
+                                  b_block_buf,
+                                  b_block_slice_copy_step,
+                                  blockwise_gemm,
+                                  c_thread_buf,
+                                  K0BlockMainLoop);

        // shuffle C and write out
        {
......
@@ -9,6 +9,7 @@
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r3.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
+#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

@@ -24,7 +25,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
-         bool HasMainKBlockLoop>
+         bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -50,7 +51,7 @@ __global__ void
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-   GridwiseGemm::template Run<HasMainKBlockLoop>(
+   GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid,
        p_b_grid,
        p_c_grid,
@@ -109,7 +110,8 @@ template <
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
-   index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
+   index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
+   index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
    static constexpr auto I0 = Number<0>{};

@@ -242,6 +244,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

+       // check NumPrefetch
+       if constexpr(NumPrefetch == 1)
+       {
+           // 1-stage prefetch always supported
+       }
+       else if constexpr(NumPrefetch == 2)
+       {
+           // 2-stage prefetch currently only supports an even number of K0 loops
+           // TODO: add support for odd number of K0 loops
+           if(!((K0 / K0PerBlock) % 2 == 0))
+           {
+               return false;
+           }
+       }
+       else
+       {
+           return false;
+       }
+
        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};
@@ -267,9 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
        return grid_size;
    }

+   // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
-       const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+       const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }
@@ -354,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-   template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+   template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
@@ -510,51 +532,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-       // preload data into LDS
-       {
-           a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-           b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-           a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-           b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-       }
-
-       // Initialize C
-       c_thread_buf.Clear();
-
-       // main body
-       if constexpr(HasMainKBlockLoop)
-       {
-           index_t k0_block_data_begin = 0;
-
-           do
-           {
-               a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-               b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-
-               a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-
-               block_sync_lds();
-
-               b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-               blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-               block_sync_lds();
-
-               a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-               b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-
-               k0_block_data_begin += K0PerBlock;
-           } while(k0_block_data_begin < (K0 - K0PerBlock));
-       }
-
-       // tail
-       {
-           block_sync_lds();
-           blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-       }
+       // gridwise GEMM pipeline
+       const auto gridwise_gemm_pipeline =
+           GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
+                                   remove_cvref_t<decltype(a_blockwise_copy)>,
+                                   remove_cvref_t<decltype(a_grid_buf)>,
+                                   remove_cvref_t<decltype(a_block_buf)>,
+                                   remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
+                                   remove_cvref_t<decltype(b_blockwise_copy)>,
+                                   remove_cvref_t<decltype(b_grid_buf)>,
+                                   remove_cvref_t<decltype(b_block_buf)>,
+                                   remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                   remove_cvref_t<decltype(blockwise_gemm)>,
+                                   remove_cvref_t<decltype(c_thread_buf)>,
+                                   NumPrefetch,
+                                   HasMainK0BlockLoop>{};
+
+       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+       gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+                                  a_block_desc_k0_m_k1,
+                                  a_blockwise_copy,
+                                  a_grid_buf,
+                                  a_block_buf,
+                                  a_block_slice_copy_step,
+                                  b_grid_desc_k0_n_k1,
+                                  b_block_desc_k0_n_k1,
+                                  b_blockwise_copy,
+                                  b_grid_buf,
+                                  b_block_buf,
+                                  b_block_slice_copy_step,
+                                  blockwise_gemm,
+                                  c_thread_buf,
+                                  K0BlockMainLoop);

        // shuffle C and write out
        {
......
@@ -64,9 +64,10 @@ template <typename SliceLengths,
          bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                           // RunRead(), will be fused with MoveSrcSliceWindow to
                                           // save addr computation
-         bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
+         bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
+         index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1
{
    static constexpr index_t nDim = SliceLengths::Size();
@@ -78,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

+   static constexpr auto I0 = Number<0>{};
+
    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
@@ -102,9 +105,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

-   template <typename SrcBuffer, typename SrcStepHacks>
-   __device__ void
-   RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
+   template <typename SrcBuffer, index_t ThreadScratchId = 0>
+   __device__ void RunRead(const SrcDesc& src_desc,
+                           const SrcBuffer& src_buf,
+                           Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -114,9 +118,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
            "wrong! SrcBuffer and SrcData data type are inconsistent");

-       constexpr auto I0 = Number<0>{};
-       constexpr auto I1 = Number<1>{};
-
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
@@ -138,8 +139,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

-               return make_tensor_coordinate_step(
-                   src_desc, forward_step_idx, src_step_hacks[I0][i]);
+               return make_tensor_coordinate_step(src_desc, forward_step_idx);
            },
            Number<nDim>{});
@@ -152,8 +152,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

-               return make_tensor_coordinate_step(
-                   src_desc, backward_step_idx, src_step_hacks[I1][i]);
+               return make_tensor_coordinate_step(src_desc, backward_step_idx);
            },
            Number<nDim>{});
@@ -215,7 +214,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            });

            // copy data from src_vector_container into src_thread_scratch_
-           src_thread_scratch_.template SetAsType<src_vector_t>(
+           src_thread_scratch_tuple_(thread_scratch_id)
+               .template SetAsType<src_vector_t>(
                src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
...@@ -263,12 +263,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -263,12 +263,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
} }
} }
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() template <index_t ThreadScratchId>
__device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{ {
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here // convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]); dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple[thread_scratch_id][idx]);
}); });
#else #else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
@@ -318,7 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        const auto src_vector_refs = generate_tie(
            [&](auto i) -> const src_vector_t& {
                // i increment corresponds to movement in DstVectorDim
-               return src_thread_scratch_.GetVectorTypeReference(
+               return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
                    data_idx_seq + i * dst_scalar_step_in_vector);
            },
            Number<num_src_vector>{});
@@ -342,19 +345,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        {
            static_ford<SliceLengths>{}([&](auto idx) {
                // convert from SrcData to DstData here
-               dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
+               dst_thread_scratch_(idx) =
+                   type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
            });
        }
#endif
    }

-   template <typename DstBuffer, typename DstStepHacks>
-   __device__ void
-   RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
+   template <typename DstBuffer, index_t ThreadScratchId = 0>
+   __device__ void RunWrite(const DstDesc& dst_desc,
+                            DstBuffer& dst_buf,
+                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        // if there is transpose, it's done here
        // TODO move this elsewhere
-       TransferDataFromSrcThreadScratchToDstThreadScratch();
+       TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);

        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -364,9 +369,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "wrong! SrcBuffer or DstBuffer data type is wrong");

-       constexpr auto I0 = Number<0>{};
-       constexpr auto I1 = Number<1>{};
-
        // src scalar per access on each dim
        // TODO: don't use this
        constexpr auto dst_scalar_per_access = generate_sequence(
@@ -388,8 +390,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

-               return make_tensor_coordinate_step(
-                   dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
+               return make_tensor_coordinate_step(dst_desc, forward_step_idx);
            },
            Number<nDim>{});
@@ -402,8 +403,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

-               return make_tensor_coordinate_step(
-                   dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
+               return make_tensor_coordinate_step(dst_desc, backward_step_idx);
            },
            Number<nDim>{});
@@ -515,39 +515,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        }
    }

-   template <typename SrcBuffer>
-   __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
-   {
-       constexpr index_t ntransform_src = remove_cvref_t<SrcDesc>::GetNumOfTransform();
-
-       constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
-
-       constexpr auto src_step_hacks =
-           make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                      generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-       RunRead(src_desc, src_buf, src_step_hacks);
-   }
-
-   template <typename DstBuffer>
-   __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
-   {
-       // TODO: why need remove_cvref_t ?
-       constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
-
-       constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
-
-       constexpr auto dst_step_hacks =
-           make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                      generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-       RunWrite(dst_desc, dst_buf, dst_step_hacks);
-   }
-
    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
-       constexpr auto I0 = Number<0>{};
-
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
@@ -606,8 +575,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
-       constexpr auto I0 = Number<0>{};
-
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
@@ -679,25 +646,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

-   // src_slice_origin_step_idx need to be known at compile-time, for performance reason
-   template <typename SrcMoveSliceWindowStepHack>
-   __device__ void
-   MoveSrcSliceWindow(const SrcDesc& src_desc,
-                      const Index& src_slice_origin_step_idx,
-                      const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
-   {
-       // if src coord was not reset by RunRead(), then need to adjust the step here
-       const auto adjusted_step_idx =
-           SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                      : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-       // is it OK to construct a new step every time?
-       const auto adjusted_step = make_tensor_coordinate_step(
-           src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
-
-       move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
-   }
-
    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
...@@ -815,19 +763,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -815,19 +763,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr, using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
SrcData, SrcData,
SrcScalarPerVector, SrcScalarPerVector,
decltype(src_thread_scratch_desc_), decltype(src_thread_scratch_desc_),
true> true>;
src_thread_scratch_;
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr, using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
DstData, DstData,
DstScalarPerVector, DstScalarPerVector,
decltype(dst_thread_scratch_desc_), decltype(dst_thread_scratch_desc_),
true> true>;
dst_thread_scratch_;
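// With NumThreadScratch > 1, each prefetch stage gets its own source-side scratch buffer,
// indexed by a compile-time thread_scratch_id; the destination-side scratch remains single-buffered.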
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
DstThreadScratch dst_thread_scratch_;
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
......
...@@ -26,6 +26,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ...@@ -26,6 +26,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
......
...@@ -52,7 +52,8 @@ template <typename ADataType, ...@@ -52,7 +52,8 @@ template <typename ADataType,
ck::index_t BBlockTransferDstScalarPerVector_K1, ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1>
struct DeviceGemmXdl struct DeviceGemmXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -218,7 +219,8 @@ struct DeviceGemmXdl ...@@ -218,7 +219,8 @@ struct DeviceGemmXdl
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector,
NumPrefetch>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -494,7 +496,12 @@ struct DeviceGemmXdl ...@@ -494,7 +496,12 @@ struct DeviceGemmXdl
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -4,9 +4,7 @@ ...@@ -4,9 +4,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp" #include "device_gemm.hpp"
#include "device_gemm_xdl.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -54,7 +52,8 @@ template < ...@@ -54,7 +52,8 @@ template <
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
index_t NumPrefetch = 1>
struct DeviceGemmXdl_C_Shuffle struct DeviceGemmXdl_C_Shuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -174,7 +173,8 @@ struct DeviceGemmXdl_C_Shuffle ...@@ -174,7 +173,8 @@ struct DeviceGemmXdl_C_Shuffle
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
CBlockTransferScalarPerVector_NWaveNPerXdl>; CBlockTransferScalarPerVector_NWaveNPerXdl,
NumPrefetch>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
......
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#####################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num|
//#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch|
//#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
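For orientation, a minimal usage sketch (not part of this commit) of how the new 2-stage instances could be enumerated and inspected through the DeviceGemmPtr interface; the two include names and the availability of GetTypeString() through the base pointer are assumptions based on the surrounding files:

#include <iostream>
#include <vector>
#include "element_wise_operation.hpp"    // assumption: provides PassThrough
#include "device_operation_instance.hpp" // assumption: provides DeviceGemmPtr and the add_* declaration

int main()
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    std::vector<ck::tensor_operation::device::DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>
        gemm_ptrs;

    // register the 2-stage prefetch instances defined above
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);

    for(const auto& gemm_ptr : gemm_ptrs)
        std::cout << gemm_ptr->GetTypeString() << std::endl; // assumption: exposed via the base class

    return 0;
}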
...@@ -26,10 +26,10 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -26,10 +26,10 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
...@@ -46,10 +46,25 @@ using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = ...@@ -46,10 +46,25 @@ using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances =
// clang-format on // clang-format on
>; >;
// irregular tile size
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -11,13 +11,23 @@ ...@@ -11,13 +11,23 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp" #include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
...@@ -31,45 +41,56 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -31,45 +41,56 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< #if 1
ADataType, // ADataType using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
BDataType, // BDataType //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
CDataType, // CDataType //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
AccDataType, // AccDataType //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
ALayout, // ALayout //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
BLayout, // BLayout // [256, 128, 4, 8], 1 stage, 2 occupancy
CLayout, // CLayout < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
AElementOp, // AElementwiseOperation #elif 0
BElementOp, // BElementwiseOperation using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
CElementOp, // CElementwiseOperation //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
256, // BlockSize //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
256, // MPerBlock //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
128, // NPerBlock //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
4, // K0PerBlock // [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size
8, // K1 // 99 TFlops, 120 blocks (1024x2160x3840)
32, // MPerXDL // 99 TFlops, 960 blocks (4096x4320x3840)
32, // NPerXDL < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>;
4, // MXdlPerWave // [128, 144, 4, 8], 1 stage, 2 occupancy,
2, // NXdlPerWave // 92 TFlops, 120 blocks (1024x2160x3840)
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 // 120 TFlops, 240 blocks (1024x4320x3840)
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder // 128 TFlops, 960 blocks (4096x4320x3840)
S<1, 0, 2>, // ABlockTransferSrcAccessOrder // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>;
2, // ABlockTransferSrcVectorDim // [ 64, 144, 8, 8], 1 stage, 2 occupancy
8, // ABlockTransferSrcScalarPerVector // 96 TFlops, 240 blocks (1024x2160x3840)
8, // ABlockTransferDstScalarPerVector_K1 // 96 TFlops, 480 blocks (1024x4320x3840)
true, // ABlockLdsAddExtraM // 99 TFlops,1920 blocks (4096x4320x3840)
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>;
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder // [ 64, 144, 8, 8], 2 stage, 2 occupancy
S<1, 0, 2>, // BBlockTransferSrcAccessOrder // 93 TFlops
2, // BBlockTransferSrcVectorDim // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>;
8, // BBlockTransferSrcScalarPerVector // [ 64, 144, 4, 8], 1 stage, 2 occupancy
8, // BBlockTransferDstScalarPerVector_K1 // 87 TFlops
true, // BBlockLdsAddExtraN // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>;
1, // CShuffleMXdlPerWavePerShuffle // [ 64, 144, 4, 8], 2 stage, 2 occupancy
1, // CShuffleNXdlPerWavePerShuffle // 85 TFlops
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>;
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl #elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8, 1>;
#endif
// clang-format on // clang-format on
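// Note: the #if / #elif chain above selects exactly one DeviceGemmInstance at compile time;
// the commented-out candidates record alternative tile shapes with their measured TFlops and occupancy.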
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -198,8 +219,8 @@ int main(int argc, char* argv[]) ...@@ -198,8 +219,8 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl; << gemm.GetTypeString() << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_fwd.hpp" #include "reference_conv_fwd.hpp"
#include "convolution_utility.hpp" #include "convolution_utility.hpp"
...@@ -35,9 +36,8 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -35,9 +36,8 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
// clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
OutDataType, // OutDataType OutDataType, // OutDataType
...@@ -69,11 +69,8 @@ using DeviceConvFwdInstance = ck::tensor_operation::device:: ...@@ -69,11 +69,8 @@ using DeviceConvFwdInstance = ck::tensor_operation::device::
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1 8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 7, // CThreadTransferSrcDstVectorDim
1, // CShuffleNXdlPerWavePerShuffle 1>; // CThreadTransferDstScalarPerVector
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using ReferenceConvFwdInstance = ck::tensor_operation::host:: using ReferenceConvFwdInstance = ck::tensor_operation::host::
ReferenceConvFwd<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>; ReferenceConvFwd<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
......
#pragma once #pragma once
#include <iomanip>
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -30,6 +31,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De ...@@ -30,6 +31,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
...@@ -225,6 +229,9 @@ void profile_gemm_impl(int do_verification, ...@@ -225,6 +229,9 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
...@@ -293,8 +300,8 @@ void profile_gemm_impl(int do_verification, ...@@ -293,8 +300,8 @@ void profile_gemm_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< " GB/s, " << gemm_name << std::endl; << gb_per_sec << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
......