Commit 8929bde2 authored by wangshaojie6

add second blockwise GEMM with A in VGPR. WIP: second GEMM pipeline

parent 580e9484
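For orientation, here is a minimal standalone sketch (plain C++, illustrative only; all names, sizes, and the scalar loops are hypothetical and not Composable Kernel API) of the data flow this commit works toward: the first GEMM's accumulator stays thread-local (standing in for VGPR), is type-converted, and then feeds the second GEMM's A operand directly instead of being staged through LDS.

#include <cstdio>

// Illustrative model only: per-thread scalar tiles stand in for VGPR-resident
// fragments; dimensions are hypothetical.
constexpr int M = 4, K0 = 4, N0 = 4, N1 = 4;

int main()
{
    float A[M][K0]   = {};
    float B0[K0][N0] = {};
    float B1[N0][N1] = {};
    for(int i = 0; i < M; ++i)
        for(int k = 0; k < K0; ++k)
            A[i][k] = 0.01f * (i + k);
    for(int k = 0; k < K0; ++k)
        for(int j = 0; j < N0; ++j)
            B0[k][j] = 0.02f * (k - j);
    for(int j = 0; j < N0; ++j)
        for(int l = 0; l < N1; ++l)
            B1[j][l] = 0.03f * (j + l);

    // GEMM0: accumulator kept in a thread-local array (the "VGPR" tile).
    float acc0[M][N0] = {};
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N0; ++j)
            for(int k = 0; k < K0; ++k)
                acc0[i][j] += A[i][k] * B0[k][j];

    // Stand-in for UnaryTypeConvert<half_t, float>: the real code converts the
    // fp32 accumulator to fp16 before it becomes gemm1's A operand.
    float a1[M][N0];
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N0; ++j)
            a1[i][j] = acc0[i][j];

    // GEMM1: a1 is consumed directly; it is never written to a shared scratch
    // buffer (no LDS round trip for the A operand).
    float acc1[M][N1] = {};
    for(int i = 0; i < M; ++i)
        for(int l = 0; l < N1; ++l)
            for(int j = 0; j < N0; ++j)
                acc1[i][l] += a1[i][j] * B1[j][l];

    std::printf("acc1[0][0] = %f\n", acc1[0][0]);
    return 0;
}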
#pragma once
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename BK0NK1BlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_a_lds
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t KPerBlock = K0PerBlock * KPack;
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_a_lds()
{
static_assert(BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ void MoveABlockSliceWindow()
{
a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
make_multi_index(0, 0, 0, K0PerBlock * KPack));
}
__device__ void ResetABlockStartWindow()
{
a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
}
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_thread_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B from LDS into thread registers (A is already in VGPR)
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
constexpr index_t k0 = k / KPack;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(k0, m0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
private:
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// A[K0PerThread, M0, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // KPerThread
Number<MRepeat>{}, // repeat
Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
} // namespace ck
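A side note on GetWaveIdx() in the struct above: the single-stage merge transform simply decomposes the flat thread id into (waveId_m, waveId_n, lane) by mixed-radix division, with the lane index fastest. A tiny standalone check of that arithmetic (plain C++; the MWaves/NWaves values are hypothetical):

#include <cstdio>

// Mirrors the merge transform in GetWaveIdx(): thread_id ->
// (waveId_m, waveId_n, lane) for MWaves x NWaves waves of WaveSize lanes.
constexpr int WaveSize = 64, MWaves = 2, NWaves = 2; // hypothetical block shape

int main()
{
    const int tids[] = {0, 63, 64, 200};
    for(int tid : tids)
    {
        const int lane     = tid % WaveSize;
        const int waveId_n = (tid / WaveSize) % NWaves;
        const int waveId_m = tid / (WaveSize * NWaves);
        std::printf("tid %3d -> waveId_m %d, waveId_n %d, lane %2d\n",
                    tid, waveId_m, waveId_n, lane);
    }
    return 0;
}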
......@@ -20,7 +20,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_b_lds
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
......@@ -201,6 +201,15 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
}
};
template <>
struct UnaryTypeConvert<ck::half_t, float>
{
__host__ __device__ void operator()(ck::half_t& y, float& x) const
{
y = ck::type_convert<ck::half_t, float>(x);
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -85,6 +85,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
// gemm1 K1
static constexpr auto AccK1 = I4;
static constexpr auto Gemm1K0PerBlock = Number<KPerBlock / AccK1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t M0Waves = M0PerBlock / (M0XdlPerWave * M0PerXDL);
......@@ -101,7 +102,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipelineSkipLds;
using GridwiseGemmPipe0 = GridwiseGemmPipelineSkipBLds;
using GridwiseGemmPipe1 = GridwiseGemmPipelineAInVgpr;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -358,23 +360,22 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
decltype(MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
using B0GridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
decltype(MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(B0GridDesc_K0_N_K1{}));
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2&)
{
}
using TypeConvertFp32ToFp16Functor =
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::half_t, float>;
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b0_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const B0GridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const B1GridDesc_K0_N_K1& b1_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -383,12 +384,14 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto Gemm0K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N]
const auto block_work_idx =
......@@ -474,13 +477,13 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
wave_k_n_id[I0],
wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
#endif
auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
Sequence<I1,
I1,
......@@ -495,16 +498,14 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_multi_index(
0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// GEMM0 definition
// c_mtx += b_mtx * a_mtx
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<
BlockSize,
......@@ -530,27 +531,27 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
// gridwise GEMM 0 pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe0>);
const auto gridwise_gemm_pipeline0 = GridwiseGemmPipe0{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop,
>
gridwise_gemm_pipeline0.template Run<HasMainKBlockLoop,
MultiK0>
(a_grid_desc_k0_m_k1,
a_block_desc_k0_m_k1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_threadwise_copy,
b_grid_buf,
b_thread_buf,
b_thread_slice_copy_step,
blockwise_gemm,
c_thread_buf,
......@@ -589,35 +590,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// A1 matrix blockwise copy
// actually a threadwise copy. this variant needs to support RunRead() and RunWrite()
// TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_v3r1<
Sequence<m0 * m1 * m2 * m3, n0 * n1 * n2, m4>{}, // ThreadSliceLengths
tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation
tensor_operation::element_wise::PassThrough, // DstElementwiseOperation
InMemoryDataOperationEnum::Set, // DstInMemOp
FloatGemmAcc, // SrcData
FloatAB, // DstData
a1_thread_desc_k0_m_k1, // SrcDesc
a1_thread_desc_k0_m_k1, // DstDesc
Sequence<1, 0, 2>, // SrcDimAccessOrder
Sequence<1, 0, 2>, // DstDimAccessOrder
2, // SrcVectorDim
2, // DstVectorDim
m4, // SrcScalarPerVector
m4, // DstScalarPerVector
1, // SrcScalarStrideInVector
1, // DstScalarStrideInVector
false, // ThreadTransferSrcResetCoordinateAfterRun
true, // ThreadTransferDstResetCoordinateAfterRun
NumGemmKPrefetchStage>(a1_thread_desc_k0_m_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{},
a1_thread_desc_k0_m_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
......@@ -649,14 +622,62 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
// reuse LDS space for gemm0's a_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared),
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a1_thread_slice_copy_step = make_multi_index(Gemm1K0PerBlock, 0, 0, 0, 0, 0, 0, 0);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1K0PerBlock, 0, 0);
// GEMM1 definition
// c_mtx += a_mtx * b_mtx
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm1 = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_a_lds<
BlockSize,
FloatAB,
FloatAcc,
decltype(b1_block_desc_bk0_n_bk1),
MPerBlock,
NPerBlock,
Gemm1K0PerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
// gridwise GEMM 1 pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe1>);
const auto gridwise_gemm_pipeline1 = GridwiseGemmPipe1{};
const index_t num_k_block_main_loop_1 = __builtin_amdgcn_readfirstlane(
(a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline1.template Run<TypeConvertFp32ToFp16Functor,
MultiK0>
(a1_thread_desc_k0_m_k1,
c_thread_buf,
a1_thread_buf,
a1_thread_slice_copy_step,
b1_grid_desc_bk0_n_bk1,
b1_block_desc_bk0_n_bk1,
b1_blockwise_copy,
b1_grid_buf,
b1_block_buf,
b1_block_slice_copy_step,
blockwise_gemm1,
c1_thread_buf,
num_k_block_main_loop_1);
// output: register to global memory
......
......@@ -18,7 +18,7 @@ __device__ void s_nop()
#endif
}
struct GridwiseGemmPipelineSkipLds
struct GridwiseGemmPipelineSkipBLds
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
......@@ -32,6 +32,7 @@ struct GridwiseGemmPipelineSkipLds
}
template <bool HasMainLoop,
index_t MultK0,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
......@@ -45,8 +46,7 @@ struct GridwiseGemmPipelineSkipLds
typename BThreadBuffer,
typename BThreadTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
index_t MultK0>
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
......@@ -57,7 +57,7 @@ struct GridwiseGemmPipelineSkipLds
const BThreadDesc& b_thread_desc,
BThreadTransfer& b_threadwise_copy,
const BGridBuffer& b_grid_buf,
BThreadBuffer& b_thread_buf[MultK0],
BThreadBuffer* b_thread_buf,
const BThreadTransferStep& b_thread_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
......@@ -142,4 +142,117 @@ struct GridwiseGemmPipelineSkipLds
}
};
struct GridwiseGemmPipelineAInVgpr
{
static constexpr auto I0 = Number<0>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop >= 2;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 2;
}
template <typename TypeConvertFp32ToFp16Functor,
typename AThreadDesc,
typename AccThreadBuffer,
typename AThreadBuffer,
index_t i_k0,
index_t n0,
index_t m4>
__device__ static void ConvertCopy(const AThreadDesc& a_thread_desc,
const AccThreadBuffer& acc_thread_buf,
AThreadBuffer& a_thread_buf)
{
constexpr auto i_k0_num = Number<i_k0>{};
static_for<0, n0, 1>{}([&](auto n) {
static_for<0, m4, 1>{}([&](auto m) {
constexpr auto acc_offset =
a_thread_desc.CalculateOffset(make_tuple(i_k0_num, n, I0, I0, I0, I0, m, I0));
constexpr auto a_offset =
a_thread_desc.CalculateOffset(make_tuple(I0, n, I0, I0, I0, I0, m, I0));
TypeConvertFp32ToFp16Functor{}(a_thread_buf(Number<a_offset>{}),
acc_thread_buf(Number<acc_offset>{}));
});
});
}
template <typename TypeConvertFp32ToFp16Functor,
index_t MultK0,
typename AThreadDesc,
typename AccThreadBuffer,
typename AThreadBuffer,
typename AThreadTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AThreadDesc& a_thread_desc,
const AccThreadBuffer& acc_thread_buf,
AThreadBuffer& a_thread_buf,
const AThreadTransferStep& a_thread_transfer_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
static_for<0, MultK0, 1>{}([&](auto i_k0) {
// TODO: ConvertCopy acc slice i_k0 (fp32) into a_thread_buf (fp16)
});
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// B data write to LDS
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// main body
if(num_loop > 2)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_thread_buf, b_block_buf, c_thread_buf);
block_sync_lds();
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
i += 1;
} while(i < (num_loop - 2));
}
// tail: the last two K0 blocks run whether or not there was a main loop
{
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_thread_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_thread_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
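For reference, a compact standalone model (plain C++; the tile values, buffer names, and loop count are hypothetical) of the single-buffered schedule sketched in GridwiseGemmPipelineAInVgpr::Run: prefetch one B tile, then per main-loop iteration run the GEMM on the tile already in LDS, write the prefetched tile, and issue the next global read, followed by the two-step tail.

#include <cstdio>
#include <vector>

// Illustrative schedule only: integers stand in for B tiles, one slot stands in
// for the LDS buffer, and the blockwise GEMM is replaced by a running sum.
int main()
{
    const int num_loop = 5; // hypothetical number of K0 blocks
    std::vector<int> b_global(num_loop);
    for(int i = 0; i < num_loop; ++i)
        b_global[i] = i + 1;

    int b_fetched = b_global[0]; // prologue: global read 0
    int b_lds     = b_fetched;   // prologue: LDS write 0
    b_fetched     = b_global[1]; // prologue: global read 1
    long long acc = 0;           // "c_thread_buf"

    int i = 0;
    while(i < num_loop - 2) // main body (empty when num_loop == 2)
    {
        acc += b_lds;                // GEMM i, using the tile in LDS
        b_lds     = b_fetched;       // LDS write i + 1
        b_fetched = b_global[i + 2]; // global read i + 2
        ++i;
    }

    // tail: the last two tiles (num_loop - 2 and num_loop - 1)
    acc += b_lds;      // GEMM num_loop - 2
    b_lds = b_fetched; // LDS write num_loop - 1
    acc += b_lds;      // GEMM num_loop - 1

    std::printf("acc = %lld (expected %d)\n", acc, num_loop * (num_loop + 1) / 2);
    return 0;
}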
......@@ -1180,6 +1180,11 @@ struct ThreadwiseTensorSliceTransfer_v4
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
}
__device__ void SetSrcCoord(const Index& src_ref_idx)
{
src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
}
private:
SrcCoord src_ref_coord_;
};
......