Commit 40016f20 authored by Jing Zhang's avatar Jing Zhang
Browse files

add m/n repeats

parent 8c84c0b1
......@@ -30,10 +30,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t WaveSize = 64;
static constexpr index_t MPerBlock = ABlockDesc{}.GetLength(I1); // A is transposed
static constexpr index_t NPerBlock = BBlockDesc{}.GetLength(I1);
static constexpr index_t MWaves = MPerBlock / MPerWave;
static constexpr index_t NWaves = NPerBlock / NPerWave;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// static constexpr index_t MPerBlock = M0 * M1; // A is transposed
// static constexpr index_t NPerBlock = N0 * N1;
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
......@@ -59,13 +69,13 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
return make_tuple(k_offset, m_offset);
return make_tuple(k_offset, 0, m_offset);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, m_offset);
return make_tuple(k_offset, 0, m_offset);
}
}
......@@ -81,26 +91,30 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
return make_tuple(k_offset, n_offset);
return make_tuple(k_offset, 0, n_offset);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, n_offset);
return make_tuple(k_offset, 0, n_offset);
}
}
template <index_t AStride = MPerWave, index_t BStride = NPerWave>
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t blk_i)
__device__ static CIndex CalculateCThreadOriginDataIndex(const index_t m_repeat_id,
const index_t n_repeat_id,
const index_t blk_i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t row = m_repeat_id * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
const index_t col = n_repeat_id * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
return CIndex{row, col};
}
......@@ -115,8 +129,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");
// static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
// static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
......@@ -136,39 +150,78 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0),
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0),
make_tuple(I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0),
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0),
make_tuple(I0, I1, I0),
b_thread_buf);
xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
});
}
private:
// A[K, M]
static constexpr auto a_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerWave>{}, Number<1>{}));
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerWave>{}, Number<1>{}));
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerWave, 1>,
Sequence<0, 1>,
1,
Sequence<KPerWave, 1, 1>,
Sequence<0, 1, 2>,
2,
1,
1>;
......@@ -176,9 +229,9 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerWave, 1>,
Sequence<0, 1>,
1,
Sequence<KPerWave, 1, 1>,
Sequence<0, 1, 2>,
2,
1,
1>;
......
......@@ -278,41 +278,42 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0, "wrong!");
// constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
// constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
// a_k_m_block_desc,
// make_tuple(
// make_pass_through_transform(Number<KPerBlock>{}),
// make_unmerge_transform(make_tuple(
// Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
// constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
// b_k_n_block_desc,
// make_tuple(
// make_pass_through_transform(Number<KPerBlock>{}),
// make_unmerge_transform(make_tuple(
// Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
// constexpr auto c_m0_m1_n0_n1_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
// Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm = BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
FloatAB,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
MPerWave,
NPerWave,
KPerWave>{};
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc),
MPerWave,
NPerWave,
KPerWave>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
......@@ -483,50 +484,56 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// static_assert(BlkSize == 16 && NumBlks == 4, "");
// force unrolling the output loop to get ride of scratches
static_for<0, NumBlks, 1>{}([&](auto i) {
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) =
c_thread_buf.template AsType<float>()[Number<i * BlkSize + j>{}];
static_for<0, MRepeat, 1>{}([&](auto m_i) {
static_for<0, NRepeat, 1>{}([&](auto n_i) {
// force unrolling the output loop to get ride of scratches
static_for<0, NumBlks, 1>{}([&](auto i) {
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) = c_thread_buf.template AsType<
float>()[Number<m_i*(NRepeat * BlkSize * NumBlks) +
n_i*(BlkSize * NumBlks) + i * BlkSize + j>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(m_i, n_i, i);
const index_t k_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t b_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<M0, 1, M2, 1>,
Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
3, // CThreadTransferSrcDstVectorDim,
1, // CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(k_thread_data_on_global / (M2 * M1),
k_thread_data_on_global % (M2 * M1) / M2,
k_thread_data_on_global % M2,
b_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
});
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(i);
const index_t k_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t b_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<M0, 1, M2, 1>,
Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
3, // CThreadTransferSrcDstVectorDim,
1, // CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(k_thread_data_on_global / (M2 * M1),
k_thread_data_on_global % (M2 * M1) / M2,
k_thread_data_on_global % M2,
b_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
});
}
}
......
......@@ -599,6 +599,34 @@ struct XdlopsGemm
});
}
template <class ADesc,
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run2(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
is_same<data_type, ushort>::value,
"base data_type must be float, half, ushort!");
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
mfma_type.template run<MPerXdlops, NPerXdlops>(
p_a_wave[Number<a_offset>{}],
p_b_wave[Number<b_offset>{}],
p_c_thread.template AsType<float16_t>()(Number<c_offset>{}));
});
}
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
{
const index_t xdlops_i = i / GetNumBlksPerXdlops();
......
......@@ -278,11 +278,9 @@ struct intrin_mfma_f32_32x32x2f32;
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
__device__ static void
Run(const float& reg_a, const float& reg_b, vector_type<float, 16>& reg_c)
__device__ static void Run(const float& reg_a, const float& reg_b, float16_t& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
reg_c = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c, 0, 0, 0);
}
};
......
......@@ -104,21 +104,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#else
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 32;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPerWave = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment