Commit 40016f20 authored by Jing Zhang

add m/n repeats
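This change splits the blockwise XDLOPS GEMM tile into repeat slices: the A/B block and thread descriptors gain an explicit repeat dimension (M0 = MRepeat, N0 = NRepeat), the inner K loop is hand-unrolled into one Run2 MFMA call per (m_repeat, n_repeat) pair, CalculateCThreadOriginDataIndex now takes the repeat indices, and the gridwise kernel plus the v4r4 XDLOPS forward-convolution driver move to a 64x64 block tile with MRepeat = NRepeat = 2.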

parent 8c84c0b1
@@ -30,10 +30,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
     static constexpr index_t WaveSize = 64;
 
-    static constexpr index_t MPerBlock = ABlockDesc{}.GetLength(I1); // A is transposed
-    static constexpr index_t NPerBlock = BBlockDesc{}.GetLength(I1);
-
-    static constexpr index_t MWaves = MPerBlock / MPerWave;
-    static constexpr index_t NWaves = NPerBlock / NPerWave;
+    static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
+    static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
+
+    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
+    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
+
+    // static constexpr index_t MPerBlock = M0 * M1; // A is transposed
+    // static constexpr index_t NPerBlock = N0 * N1;
+
+    static constexpr index_t MWaves = M1 / MPerWave;
+    static constexpr index_t NWaves = N1 / NPerWave;
+
+    static constexpr index_t MRepeat = M0;
+    static constexpr index_t NRepeat = N0;
 
     __device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
@@ -59,13 +69,13 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         {
             const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
             const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
-            return make_tuple(k_offset, m_offset);
+            return make_tuple(k_offset, 0, m_offset);
         }
         else
         {
             const index_t m_offset = waveId_m * MPerWave + laneId;
             const index_t k_offset = 0;
-            return make_tuple(k_offset, m_offset);
+            return make_tuple(k_offset, 0, m_offset);
         }
     }
@@ -81,26 +91,30 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         {
             const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
             const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
-            return make_tuple(k_offset, n_offset);
+            return make_tuple(k_offset, 0, n_offset);
         }
         else
        {
             const index_t n_offset = waveId_n * NPerWave + laneId;
             const index_t k_offset = 0;
-            return make_tuple(k_offset, n_offset);
+            return make_tuple(k_offset, 0, n_offset);
         }
     }
 
-    template <index_t AStride = MPerWave, index_t BStride = NPerWave>
-    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t blk_i)
+    __device__ static CIndex CalculateCThreadOriginDataIndex(const index_t m_repeat_id,
+                                                             const index_t n_repeat_id,
+                                                             const index_t blk_i)
     {
         const index_t waveId = get_thread_local_1d_id() / WaveSize;
 
         const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
 
-        const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
-        const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
+        const index_t waveId_m = waveId / NWaves;
+        const index_t waveId_n = waveId % NWaves;
+
+        const index_t row = m_repeat_id * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
+        const index_t col = n_repeat_id * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
 
         return CIndex{row, col};
     }
@@ -115,8 +129,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                       "wrong! K dimension not consistent");
 
-        static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
-        static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");
+        // static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
+        // static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");
 
         static_assert(BlockSize == MWaves * NWaves * WaveSize,
                       "BlockSize != MWaves * NWaves * WaveSize\n");
@@ -136,39 +150,78 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
             a_thread_copy_.Run(ABlockDesc{},
-                               make_tuple(k, I0),
+                               make_tuple(k, I0, I0),
                                a_block_buf,
                                a_thread_desc_,
-                               make_tuple(I0, I0),
+                               make_tuple(I0, I0, I0),
                                a_thread_buf);
 
             b_thread_copy_.Run(BBlockDesc{},
-                               make_tuple(k, I0),
-                               b_block_buf,
-                               b_thread_desc_,
-                               make_tuple(I0, I0),
-                               b_thread_buf);
-
-            xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
+                               make_tuple(k, I0, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I0, I0),
+                               b_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      0,
+                                      0>(a_thread_buf, b_thread_buf, c_thread_buf);
+
+            b_thread_copy_.Run(BBlockDesc{},
+                               make_tuple(k, I1, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               b_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      0,
+                                      1>(a_thread_buf, b_thread_buf, c_thread_buf);
+
+            a_thread_copy_.Run(ABlockDesc{},
+                               make_tuple(k, I1, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               a_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      1,
+                                      0>(a_thread_buf, b_thread_buf, c_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      1,
+                                      1>(a_thread_buf, b_thread_buf, c_thread_buf);
         });
     }
 
     private:
     // A[K, M]
-    static constexpr auto a_thread_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerWave>{}, Number<1>{}));
+    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1));
 
     // B[K, N]
-    static constexpr auto b_thread_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerWave>{}, Number<1>{}));
+    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1));
+
+    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
 
     using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                                 FloatA,
                                                                 ABlockDesc,
                                                                 decltype(a_thread_desc_),
-                                                                Sequence<KPerWave, 1>,
-                                                                Sequence<0, 1>,
-                                                                1,
+                                                                Sequence<KPerWave, 1, 1>,
+                                                                Sequence<0, 1, 2>,
+                                                                2,
                                                                 1,
                                                                 1>;
@@ -176,9 +229,9 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                                                                 FloatB,
                                                                 BBlockDesc,
                                                                 decltype(b_thread_desc_),
-                                                                Sequence<KPerWave, 1>,
-                                                                Sequence<0, 1>,
-                                                                1,
+                                                                Sequence<KPerWave, 1, 1>,
+                                                                Sequence<0, 1, 2>,
+                                                                2,
                                                                 1,
                                                                 1>;
......
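For readability: the hand-unrolled Run2<..., m0, n0> sequence in the K loop above corresponds to a nested loop over the repeat indices. Below is a minimal host-side sketch of the resulting MFMA issue order, using plain loops and the K tile sizes picked by the updated driver (KPerBlock = 8, KPerWave = 2); it is illustrative only and not the library's API.

```cpp
// Illustrative sketch: prints the order in which the unrolled Run2 calls
// accumulate the per-repeat C slices for each K slice of size KPerWave.
#include <cstdio>

int main()
{
    constexpr int MRepeat = 2, NRepeat = 2, KPerBlock = 8, KPerWave = 2;

    for(int k = 0; k < KPerBlock; k += KPerWave)
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int n0 = 0; n0 < NRepeat; ++n0)
                std::printf("k=%d: C(m0=%d, n0=%d) += A(k, m0=%d) * B(k, n0=%d)\n",
                            k, m0, n0, m0, n0);
    return 0;
}
```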
@@ -278,41 +278,42 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
 
         // sanity check
-        static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0, "wrong!");
-
-        // constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-        // constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
-
-        // constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
-        //     a_k_m_block_desc,
-        //     make_tuple(
-        //         make_pass_through_transform(Number<KPerBlock>{}),
-        //         make_unmerge_transform(make_tuple(
-        //             Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
-        //     make_tuple(Sequence<0>{}, Sequence<1>{}),
-        //     make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-
-        // constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
-        //     b_k_n_block_desc,
-        //     make_tuple(
-        //         make_pass_through_transform(Number<KPerBlock>{}),
-        //         make_unmerge_transform(make_tuple(
-        //             Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
-        //     make_tuple(Sequence<0>{}, Sequence<1>{}),
-        //     make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+        constexpr index_t MRepeat = 2;
+        constexpr index_t NRepeat = 2;
+
+        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
+                          NPerBlock % (NPerWave * NRepeat) == 0,
+                      "wrong!");
+
+        constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+            a_k_m_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+            b_k_n_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
 
         // constexpr auto c_m0_m1_n0_n1_thread_desc =
         //     make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
         //         Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
 
-        const auto blockwise_gemm = BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
-                                                                         FloatAB,
-                                                                         FloatAB,
-                                                                         decltype(a_k_m_block_desc),
-                                                                         decltype(b_k_n_block_desc),
-                                                                         MPerWave,
-                                                                         NPerWave,
-                                                                         KPerWave>{};
+        const auto blockwise_gemm =
+            BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
+                                                 FloatAB,
+                                                 FloatAB,
+                                                 decltype(a_k_m0_m1_block_desc),
+                                                 decltype(b_k_n0_n1_block_desc),
+                                                 MPerWave,
+                                                 NPerWave,
+                                                 KPerWave>{};
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
@@ -483,50 +484,56 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         // static_assert(BlkSize == 16 && NumBlks == 4, "");
 
-        // force unrolling the output loop to get ride of scratches
-        static_for<0, NumBlks, 1>{}([&](auto i) {
-            StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
-
-            static_for<0, BlkSize, 1>{}([&](auto j) {
-                c_thread_buf_(j) =
-                    c_thread_buf.template AsType<float>()[Number<i * BlkSize + j>{}];
-            });
-
-            // calculate origin of thread output tensor on global memory
-            // blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block =
-                blockwise_gemm.CalculateCThreadOriginDataIndex(i);
-
-            const index_t k_thread_data_on_global =
-                m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
-
-            const index_t b_thread_data_on_global =
-                n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
-
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
-
-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
-                FloatAcc,
-                FloatC,
-                decltype(c_m0_m1_m2_n_thread_desc),
-                decltype(c_m0_m1_m2_n_global_desc),
-                Sequence<M0, 1, M2, 1>,
-                Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
-                3,                    // CThreadTransferSrcDstVectorDim,
-                1,                    // CThreadTransferDstScalarPerVector,
-                CGlobalMemoryDataOperation,
-                1,
-                true>{c_m0_m1_m2_n_global_desc,
-                      make_multi_index(k_thread_data_on_global / (M2 * M1),
-                                       k_thread_data_on_global % (M2 * M1) / M2,
-                                       k_thread_data_on_global % M2,
-                                       b_thread_data_on_global)}
-                .Run(c_m0_m1_m2_n_thread_desc,
-                     make_tuple(I0, I0, I0, I0),
-                     c_thread_buf_,
-                     c_m0_m1_m2_n_global_desc,
-                     c_global_buf,
-                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+        static_for<0, MRepeat, 1>{}([&](auto m_i) {
+            static_for<0, NRepeat, 1>{}([&](auto n_i) {
+                // force unrolling the output loop to get ride of scratches
+                static_for<0, NumBlks, 1>{}([&](auto i) {
+                    StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
+
+                    static_for<0, BlkSize, 1>{}([&](auto j) {
+                        c_thread_buf_(j) = c_thread_buf.template AsType<
+                            float>()[Number<m_i*(NRepeat * BlkSize * NumBlks) +
+                                            n_i*(BlkSize * NumBlks) + i * BlkSize + j>{}];
+                    });
+
+                    // calculate origin of thread output tensor on global memory
+                    // blockwise GEMM c matrix starting index
+                    const auto c_thread_mtx_on_block =
+                        blockwise_gemm.CalculateCThreadOriginDataIndex(m_i, n_i, i);
+
+                    const index_t k_thread_data_on_global =
+                        m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
+
+                    const index_t b_thread_data_on_global =
+                        n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
+
+                    constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
+                        CGlobalIteratorHacks{};
+
+                    ThreadwiseDynamicTensorSliceTransfer_v1r3<
+                        FloatAcc,
+                        FloatC,
+                        decltype(c_m0_m1_m2_n_thread_desc),
+                        decltype(c_m0_m1_m2_n_global_desc),
+                        Sequence<M0, 1, M2, 1>,
+                        Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
+                        3,                    // CThreadTransferSrcDstVectorDim,
+                        1,                    // CThreadTransferDstScalarPerVector,
+                        CGlobalMemoryDataOperation,
+                        1,
+                        true>{c_m0_m1_m2_n_global_desc,
+                              make_multi_index(k_thread_data_on_global / (M2 * M1),
+                                               k_thread_data_on_global % (M2 * M1) / M2,
+                                               k_thread_data_on_global % M2,
+                                               b_thread_data_on_global)}
+                        .Run(c_m0_m1_m2_n_thread_desc,
+                             make_tuple(I0, I0, I0, I0),
+                             c_thread_buf_,
+                             c_m0_m1_m2_n_global_desc,
+                             c_global_buf,
+                             c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+                });
+            });
         });
     }
 }
......
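The new a_k_m0_m1_block_desc / b_k_n0_n1_block_desc above unmerge the block tile's M (or N) dimension into (MRepeat, MPerBlock / MRepeat); assuming the repeat index is the outer (slower) dimension, this is the mapping m = m0 * M1 + m1 with M1 = MPerBlock / MRepeat, the same relation CalculateCThreadOriginDataIndex uses for the output row. A small self-contained check with the values used by this commit (MPerBlock = 64, MRepeat = 2, MPerWave = 32, a single M-wave); illustrative only, not library code:

```cpp
#include <cassert>

int main()
{
    constexpr int MPerBlock = 64, MRepeat = 2, MPerWave = 32;
    constexpr int M1 = MPerBlock / MRepeat; // 32 rows per repeat slice

    // round trip of the unmerge: split m into (m0, m1), then merge it back
    for(int m = 0; m < MPerBlock; ++m)
    {
        const int m0 = m / M1;
        const int m1 = m % M1;
        assert(m == m0 * M1 + m1);
    }

    // output-row origin for repeat 1, wave 0, intra-wave row 5
    // (row = m_repeat_id * M1 + waveId_m * MPerWave + thread_row)
    const int row = 1 * M1 + 0 * MPerWave + 5;
    assert(row == 37);

    return 0;
}
```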
@@ -599,6 +599,34 @@ struct XdlopsGemm
         });
     }
 
+    template <class ADesc,
+              class BDesc,
+              class CDesc,
+              index_t m0,
+              index_t n0,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void Run2(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
+    {
+        static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
+                          is_same<data_type, ushort>::value,
+                      "base data_type must be float, half, ushort!");
+
+        static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
+
+        static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
+            constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
+            constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
+            constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
+
+            mfma_type.template run<MPerXdlops, NPerXdlops>(
+                p_a_wave[Number<a_offset>{}],
+                p_b_wave[Number<b_offset>{}],
+                p_c_thread.template AsType<float16_t>()(Number<c_offset>{}));
+        });
+    }
+
     __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
     {
         const index_t xdlops_i = i / GetNumBlksPerXdlops();
......
@@ -278,11 +278,9 @@ struct intrin_mfma_f32_32x32x2f32;
 template <>
 struct intrin_mfma_f32_32x32x2f32<32, 32>
 {
-    __device__ static void
-    Run(const float& reg_a, const float& reg_b, vector_type<float, 16>& reg_c)
+    __device__ static void Run(const float& reg_a, const float& reg_b, float16_t& reg_c)
     {
-        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+        reg_c = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c, 0, 0, 0);
     }
 };
......
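Note on the intrinsic wrapper change above: the accumulator argument is now a single float16_t (which in this codebase appears to be the 16-wide float accumulator vector) rather than a vector_type<float, 16> wrapper, so Run2 can pass the per-(m0, n0) accumulator slot it selects via AsType<float16_t>() straight to the intrinsic.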
@@ -104,21 +104,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
 #else
     constexpr index_t BlockSize = 64;
 
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 32;
+    constexpr index_t GemmMPerBlock = 64;
+    constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
     constexpr index_t GemmMPerWave = 32;
     constexpr index_t GemmNPerWave = 32;
     constexpr index_t GemmKPerWave = 2;
 
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
 
     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
 
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
     using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
 
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
......
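For reference, these parameters are consistent with the hard-coded MRepeat = NRepeat = 2 in the gridwise kernel: M1 = GemmMPerBlock / MRepeat = 32 and N1 = 32, so MWaves = NWaves = 1 and MWaves * NWaves * WaveSize = 64 = BlockSize, satisfying the blockwise GEMM's static_assert. Likewise, the A/B block-copy tiles now cover 8 x 64 elements (ThreadSliceLengths 4 x 2 times ThreadClusterLengths 2 x 32), matching the doubled GemmMPerBlock / GemmNPerBlock of 64.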