Commit 580e9484 authored by wangshaojie6

add skip lds pipeline

parent dbc971be
#pragma once
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
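// Blockwise GEMM for the skip-LDS path: the A tile is staged in LDS and read
// per-thread through a_thread_copy_, while the B tile is consumed directly from
// thread-private registers (the b_thread_buf passed to Run), so B never touches LDS.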
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0K0BN0N1N2N3K1BlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t KPerBlock = K0PerBlock * KPack;
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
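// One wave owns an (MRepeat * MPerXDL) x (NRepeat * NPerXDL) sub-tile; MWaves and
// NWaves count how many such wave tiles cover the block, and KPerThread/K0PerThread
// are each lane's share of the block's K dimension handed to the xdlops instructions.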
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
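// Decompose the flat thread id into (wave id along M, wave id along N, lane within the wave).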
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
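// The descriptor getters below expose the xdlops C output in the
// [M0, N0, M1, N1, M2, M3, M4, N2] layout (repeat, wave and intra-xdlops sub-block
// dimensions) for the thread, block and grid views used by the epilogue.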
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
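// MoveABlockSliceWindow advances this thread's A read window by one K0PerBlock * KPack
// chunk along K; ResetABlockStartWindow rewinds it to the thread's origin before the
// next K block is processed.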
__device__ void MoveABlockSliceWindow()
{
a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
make_multi_index(0, 0, 0, K0PerBlock * KPack));
}
__device__ void ResetABlockStartWindow()
{
a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
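// Run: for each m-repeat, copy one KPerThread-long A slice from the LDS block into
// registers, then for every n-repeat and KPack-wide K chunk feed the A/B register
// vectors to the xdlops MFMA, accumulating into c_thread_buf.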
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
constexpr index_t k0 = k / KPack;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
private:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B thread buffer layout: [K0PerThread, NRepeat, KPack]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // K0 per thread
Number<NRepeat>{}, // N repeat
Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
};
} // namespace ck
@@ -246,6 +246,7 @@ struct DeviceBatchedGemmGemmCShuffleXdl : public DeviceBatchedGemmGemmCShuffle<A
             VElementwiseOperation,
             PElementwiseOperation,
             OElementwiseOperation,
+            NumPrefetch,
             QKMPerBlock,
             QKNPerBlock,
             QKMPerXDL,
......
@@ -26,6 +26,7 @@ template <index_t BlockSize,
           typename C0ElementwiseOperation,
           typename B1ElementwiseOperation,
           typename C1ElementwiseOperation,
+          index_t NumGemmKPrefetchStage,
           index_t M0PerBlock,
           index_t N0PerBlock,
           index_t M0PerXDL,
@@ -82,6 +83,9 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     // K1 should be Number<...>
     static constexpr auto K1 = Number<AK1>{};
+    // gemm1 K1
+    static constexpr auto AccK1 = I4;
+
     static constexpr index_t WaveSize = 64;
     static constexpr index_t M0Waves = M0PerBlock / (M0XdlPerWave * M0PerXDL);
     static constexpr index_t N0Waves = N0PerBlock / (N0XdlPerWave * N0PerXDL);
@@ -97,6 +101,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipelineSkipLds;
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = AK1;
@@ -353,7 +359,13 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
     using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
-        decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
+        decltype(MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
+
+    __host__ __device__ static constexpr auto
+    MakeB1GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2&)
+    {
+    }
+
     template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
@@ -392,6 +404,9 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        // B1 matrix in LDS memory, dst of blockwise copy
+        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1();
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -439,7 +454,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
                          FloatAB,
                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
                          true>
-            b_thread_1st_buf, b_thread_2nd_buf, b_thread_3rd_buf, b_thread_4th_buf;
+            b_thread_buf[MultiK0];
         const auto wave_id = GetWaveIdx();
         const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
@@ -512,173 +527,137 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
-        // gridwise GEMM pipeline
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
         constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
-        // preload data to regiester and LDS
-        {
-            // Read
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  b_grid_buf,
-                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                  b_thread_1st_buf);
-            // Move
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                 b_thread_slice_copy_step);
-            // Initialize C
-            c_thread_buf.Clear();
-            // a data write to lds
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-            // load 2nd a matrix data
-            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  b_grid_buf,
-                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                  b_thread_2nd_buf);
-            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                 b_thread_slice_copy_step);
-            // main body
-            if constexpr(HasMainK0BlockLoop)
-            {
-                index_t K0BlockMainLoop =
-                    __builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
-                index_t i = 0;
-                do
-                {
-                    a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-                    blockwise_gemm.ResetABlockStartWindow();
-                    block_sync_lds();
-                    static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
-                        // 1st
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_3rd_buf);
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                        blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
-                        blockwise_gemm.MoveABlockSliceWindow();
-                        s_nop();
-                        // 2nd
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_4th_buf);
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                        blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
-                        blockwise_gemm.MoveABlockSliceWindow();
-                        s_nop();
-                        // 3rd
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_1st_buf);
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                        blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
-                        blockwise_gemm.MoveABlockSliceWindow();
-                        s_nop();
-                        // 4th
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_2nd_buf);
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                        blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
-                        blockwise_gemm.MoveABlockSliceWindow();
-                    });
-                    block_sync_lds();
-                    a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-                    // move a and b window
-                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                        a_block_slice_copy_step);
-                    i += 1;
-                } while(i < (K0BlockMainLoop - 1));
-            }
-            // tail
-            {
-                block_sync_lds();
-                blockwise_gemm.ResetABlockStartWindow();
-                static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
-                    // 1st
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_3rd_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                    // 2nd
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_4th_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                    // 3rd
-                    if constexpr(i < MultiK0 - BaseMultK0)
-                    {
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_1st_buf);
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                    }
-                    blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                    // 4th
-                    if constexpr(i < MultiK0 - BaseMultK0)
-                    {
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_2nd_buf);
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                    }
-                    blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                });
-            }
-        }
+        // gridwise GEMM pipeline
+        static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
+        const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
+            (a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2)) /
+            KPerBlock);
+
+        gridwise_gemm_pipeline.template Run<HasMainK0BlockLoop>(
+            a_grid_desc_k0_m_k1,
+            a_block_desc_k0_m_k1,
+            a_blockwise_copy,
+            a_grid_buf,
+            a_block_buf,
+            a_block_slice_copy_step,
+            b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+            b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+            b_threadwise_copy,
+            b_grid_buf,
+            b_thread_buf,
+            b_thread_slice_copy_step,
+            blockwise_gemm,
+            c_thread_buf,
+            num_k_block_main_loop);
+
+        // gemm 1 O=PV
+        // Gemm1
+        //
+        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to A data type
+        constexpr auto acc_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / AccK1, 0, 0);
+        constexpr auto b1_block_slice_copy_step  = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
+
+        constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+        // constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+        //     blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
+        constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
+        constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
+        constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
+        constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
+        constexpr auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
+        constexpr auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
+        constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
+
+        // acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a1_thread_desc_k0_m_k1
+        // m0_m1_m2_m3 -> k0
+        // n0_n1_n2 -> m
+        // m4 -> k1
+        // typical case: m0 = MRepeat, n0 = NRepeat, m4 = 4, the others are all 1
+        constexpr auto a1_thread_desc_k0_m_k1 = transform_tensor_descriptor(
+            acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            make_tuple(make_merge_transform(make_tuple(m0, m1, m2, m3)),
+                       make_merge_transform(make_tuple(n0, n1, n2)),
+                       make_pass_through_transform(m4)),
+            make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+        // A1 matrix blockwise copy
+        // actually a threadwise copy. this variant needs to support RunRead() and RunWrite()
+        // TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_v3r1<
+            Sequence<m0 * m1 * m2 * m3, n0 * n1 * n2, m4>, // ThreadSliceLengths
+            tensor_operation::element_wise::PassThrough,   // SrcElementwiseOperation
+            tensor_operation::element_wise::PassThrough,   // DstElementwiseOperation
+            InMemoryDataOperationEnum::Set,                // DstInMemOp
+            FloatGemmAcc,                                  // SrcData
+            FloatAB,                                       // DstData
+            decltype(a1_thread_desc_k0_m_k1),              // SrcDesc
+            decltype(a1_thread_desc_k0_m_k1),              // DstDesc
+            Sequence<1, 0, 2>,                             // SrcDimAccessOrder
+            Sequence<1, 0, 2>,                             // DstDimAccessOrder
+            2,                                             // SrcVectorDim
+            2,                                             // DstVectorDim
+            m4,                                            // SrcScalarPerVector
+            m4,                                            // DstScalarPerVector
+            1,                                             // SrcScalarStrideInVector
+            1,                                             // DstScalarStrideInVector
+            false,                                         // ThreadTransferSrcResetCoordinateAfterRun
+            true,                                          // ThreadTransferDstResetCoordinateAfterRun
+            NumGemmKPrefetchStage>(a1_thread_desc_k0_m_k1,
+                                   make_multi_index(0, 0, 0),
+                                   tensor_operation::element_wise::PassThrough{},
+                                   a1_thread_desc_k0_m_k1,
+                                   make_multi_index(0, 0, 0),
+                                   tensor_operation::element_wise::PassThrough{});
+
+        // B1 matrix blockwise copy
+        auto b1_blockwise_copy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                BElementwiseOperation,
+                                                tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<B1K0, Gemm1NPerBlock, B1K1>,
+                                                B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+                                                B1BlockTransferThreadClusterArrangeOrder,
+                                                FloatAB,
+                                                FloatAB,
+                                                decltype(b1_grid_desc_bk0_n_bk1),
+                                                decltype(b1_block_desc_bk0_n_bk1),
+                                                B1BlockTransferSrcAccessOrder,
+                                                Sequence<1, 0, 2>,
+                                                B1BlockTransferSrcVectorDim,
+                                                2,
+                                                B1BlockTransferSrcScalarPerVector,
+                                                B1BlockTransferDstScalarPerVector_BK1,
+                                                1,
+                                                1,
+                                                B1ThreadTransferSrcResetCoordinateAfterRun,
+                                                true,
+                                                NumGemmKPrefetchStage>(
+                b1_grid_desc_bk0_n_bk1,
+                make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+                b_element_op,
+                b1_block_desc_bk0_n_bk1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+        auto a1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+
+        // reuse LDS space for gemm0's a_block_buf
+        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatAB*>(p_shared),
+            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

         // output: register to global memory
         {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
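// s_nop(): a minimal scheduling fence between issue groups, used by the pipeline
// between an MFMA call and the next register prefetch; the #else branch shows the
// intrinsic-based alternative (__builtin_amdgcn_sched_barrier).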
__device__ inline void s_nop()
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
#else
__builtin_amdgcn_sched_barrier(0);
#endif
}
struct GridwiseGemmPipelineSkipLds
{
static constexpr auto I0 = Number<0>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop >= 2;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 2;
}
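// Run (below): preload all MultK0 per-thread B chunks and stage the first A block in LDS,
// then each main-loop iteration computes chunk i from LDS/registers while refilling
// b_thread_buf[i] with the next block's data and reading the next A tile; the tail drains
// the last prefetched block without issuing further global loads.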
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BThreadDesc,
typename BThreadTransfer,
typename BGridBuffer,
typename BThreadBuffer,
typename BThreadTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
index_t MultK0>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BThreadDesc& b_thread_desc,
BThreadTransfer& b_threadwise_copy,
const BGridBuffer& b_grid_buf,
BThreadBuffer (&b_thread_buf)[MultK0],
const BThreadTransferStep& b_thread_copy_step,
BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data to registers and LDS
// Read
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MultK0, 1>{}([&](auto i_load_b){
b_threadwise_copy.Run(b_grid_desc,
b_grid_buf,
b_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf[i_load_b]);
s_nop();
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc,
b_thread_copy_step);
});
// Initialize C
c_thread_buf.Clear();
// a data write to lds
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
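// at this point the first K block is resident: its B chunks sit in registers and its A
// tile is in LDS, so the loop below can overlap MFMA work with prefetch of the next block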
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.ResetABlockStartWindow();
block_sync_lds();
static_for<0, MultK0, 1>{}([&](auto i_main) {
blockwise_gemm.Run(a_block_buf, b_thread_buf[i_main], c_thread_buf);
// 1st
b_threadwise_copy.Run(b_grid_desc,
b_grid_buf,
b_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf[i_main]);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc,
b_thread_copy_step);
blockwise_gemm.MoveABlockSliceWindow();
s_nop();
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// move a and b window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
a_block_copy_step);
i += 1;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.ResetABlockStartWindow();
static_for<0, MultK0, 1>{}([&](auto i_tail) {
blockwise_gemm.Run(a_block_buf, b_thread_buf[i_tail], c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
}
}
};
} // namespace ck