Commit 0df8ed88 authored by Anthony Chang's avatar Anthony Chang
Browse files

roll out inter-wave loop scheduler to c-shuffle gemm variants

will gradually roll out to other applicable device ops when occasional reg spill is resolved
parent 9464c5ef
...@@ -248,10 +248,150 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -248,10 +248,150 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
static constexpr index_t KPerInnerLoop = #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
math::max(KPerThread / CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS, KPack); /// Note: we need to explicitly enable CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING because some
/// features are not yet available in latest ROCm release. If
/// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING is disabled but LoopScheduler is Interwave, fall back
/// to default loop scheduler to avoid compilation error
using Base::a_block_desc_m0_m1_m2_k;
using Base::A_K1;
using Base::b_block_desc_n0_n1_n2_k;
using Base::B_K1;
using Base::c_thread_buf_;
using Base::c_thread_desc_;
using Base::CalculateAThreadOriginDataIndex;
using Base::CalculateBThreadOriginDataIndex;
using Base::I0;
using Base::I1;
using Base::KPerThread;
using Base::xdlops_gemm;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
// 2-wave optimized blockwise gemm // 2-wave optimized blockwise gemm
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...@@ -348,7 +488,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -348,7 +488,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}); });
} }
private: protected:
// A[M0, M1, M2, KPerInnerLoop] // A[M0, M1, M2, KPerInnerLoop]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{})); make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
...@@ -377,99 +517,51 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -377,99 +517,51 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
B_K1, B_K1,
B_K1>; B_K1>;
#else // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
private:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
};
// C[M, N, NumRegXdlops] template <index_t BlockSize,
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( typename FloatAB,
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops())); typename FloatAcc,
typename AK0MK1BlockDesc,
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; typename BK0NK1BlockDesc,
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}; };
} // namespace ck } // namespace ck
...@@ -14,6 +14,8 @@ namespace ck { ...@@ -14,6 +14,8 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation failures.
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -62,7 +64,8 @@ template <typename ALayout, ...@@ -62,7 +64,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock> index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = LoopScheduler::Interwave>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation, struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -422,7 +425,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -422,7 +425,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>; CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
......
...@@ -14,6 +14,8 @@ namespace ck { ...@@ -14,6 +14,8 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation failures.
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -54,7 +56,8 @@ template <typename ALayout, ...@@ -54,7 +56,8 @@ template <typename ALayout,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Interwave>
struct DeviceGemm_Xdl_CShuffle struct DeviceGemm_Xdl_CShuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
...@@ -375,7 +378,8 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -375,7 +378,8 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>; CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
......
...@@ -79,9 +79,7 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -79,9 +79,7 @@ struct GridwiseGemmPipeline_v1<1>
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
#if !CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
block_sync_lds(); block_sync_lds();
#endif // !CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
...@@ -250,4 +248,112 @@ struct GridwiseGemmPipeline_v1<2> ...@@ -250,4 +248,112 @@ struct GridwiseGemmPipeline_v1<2>
} }
}; };
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
template <>
struct GridwiseGemmPipelineInterwave_v1<1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void
Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// block_sync_lds(); // moved into blockwise_gemm
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
template <index_t NumPrefetch,
bool HasMainLoop>
constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
}
}
} // namespace ck } // namespace ck
...@@ -134,7 +134,8 @@ template <typename FloatAB, ...@@ -134,7 +134,8 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock> index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -473,17 +474,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -473,17 +474,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr index_t KPack = math::max( constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockSize,
FloatAB, FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
KPack>{}; KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -502,25 +504,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -502,25 +504,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock); KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1, gridwise_gemm_pipeline.Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
a_blockwise_copy, a_blockwise_copy,
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// shuffle C and write out // shuffle C and write out
{ {
......
...@@ -107,7 +107,8 @@ template <typename FloatAB, ...@@ -107,7 +107,8 @@ template <typename FloatAB,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -416,17 +417,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -416,17 +417,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr index_t KPack = math::max( constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockSize,
FloatAB, FloatAB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
KPack>{}; KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -445,25 +447,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -445,25 +447,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock); KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1, gridwise_gemm_pipeline.Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
a_blockwise_copy, a_blockwise_copy,
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// shuffle C and write out // shuffle C and write out
{ {
......
...@@ -7,6 +7,12 @@ ...@@ -7,6 +7,12 @@
namespace ck { namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
enum struct MfmaInstr enum struct MfmaInstr
{ {
mfma_f32_32x32x1xf32 = 0, mfma_f32_32x32x1xf32 = 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment