Commit de9f5bed authored by Jing Zhang's avatar Jing Zhang
Browse files

add kpack into xldops_gemm and blockwise_gemm

parent 776721ab
...@@ -15,7 +15,7 @@ template <index_t BlockSize, ...@@ -15,7 +15,7 @@ template <index_t BlockSize,
class BBlockDesc, class BBlockDesc,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave> index_t KPack>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{ {
...@@ -26,8 +26,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -26,8 +26,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPerWave>{};
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
...@@ -36,6 +34,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -36,6 +34,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{};
static constexpr index_t MWaves = M1 / MPerWave; static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave; static constexpr index_t NWaves = N1 / NPerWave;
...@@ -59,14 +59,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -59,14 +59,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
{ {
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base; const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset); return make_tuple(k_offset, 0, m_offset, 0);
} }
else else
{ {
const index_t m_offset = waveId_m * MPerWave + laneId; const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0; const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset); return make_tuple(k_offset, 0, m_offset, 0);
} }
} }
...@@ -81,14 +81,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -81,14 +81,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
{ {
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base; const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset); return make_tuple(k_offset, 0, n_offset, 0);
} }
else else
{ {
const index_t n_offset = waveId_n * NPerWave + laneId; const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0; const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset); return make_tuple(k_offset, 0, n_offset, 0);
} }
} }
...@@ -120,8 +120,19 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -120,8 +120,19 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
"wrong! KPack dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...@@ -136,21 +147,21 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -136,21 +147,21 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_for<0, KPerBlock, KPerWave>{}([&](auto k) { static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A // read A
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0), make_tuple(k, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
// read B // read B
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0), make_tuple(k, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -168,11 +179,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -168,11 +179,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1)); make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1)); make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
...@@ -181,20 +192,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -181,20 +192,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
FloatA, FloatA,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<KPerWave, MRepeat, 1>, Sequence<1, MRepeat, 1, KPack>,
Sequence<0, 1, 2>, Sequence<0, 1, 2, 3>,
2, 3,
1, 1, // KPack,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB, FloatB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<KPerWave, NRepeat, 1>, Sequence<1, NRepeat, 1, KPack>,
Sequence<0, 1, 2>, Sequence<0, 1, 2, 3>,
2, 3,
1, 1, // KPack,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
...@@ -208,7 +219,7 @@ template <index_t BlockSize, ...@@ -208,7 +219,7 @@ template <index_t BlockSize,
class BBlockDesc, class BBlockDesc,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave> index_t KPack>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{ {
...@@ -219,7 +230,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -219,7 +230,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPerWave>{}; static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{};
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
...@@ -252,14 +263,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -252,14 +263,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
{ {
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base; const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset); return make_tuple(k_offset, 0, m_offset, 0);
} }
else else
{ {
const index_t m_offset = waveId_m * MPerWave + laneId; const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0; const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset); return make_tuple(k_offset, 0, m_offset, 0);
} }
} }
...@@ -274,14 +285,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -274,14 +285,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
if constexpr(xdlops_gemm.IsKReduction) if constexpr(xdlops_gemm.IsKReduction)
{ {
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base; const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset); return make_tuple(k_offset, 0, n_offset, 0);
} }
else else
{ {
const index_t n_offset = waveId_n * NPerWave + laneId; const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0; const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset); return make_tuple(k_offset, 0, n_offset, 0);
} }
} }
...@@ -313,8 +324,19 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -313,8 +324,19 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
"wrong! KPack dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...@@ -331,34 +353,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -331,34 +353,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
// read A_sub_0 // read A_sub_0
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
// read B_sub_0 // read B_sub_0
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
// read B_sub_1 // read B_sub_1
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0, I0),
b_thread_buf); b_thread_buf);
// read A_sub_1 // read A_sub_1
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 // C_sub_00 += transpose(A_sub_0) * B_sub_0
...@@ -375,13 +397,13 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -375,13 +397,13 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
0, 0,
1>(a_thread_buf, b_thread_buf, c_thread_buf); 1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<KPerWave, KPerBlock, KPerWave>{}([&](auto k) { static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A_sub_0 // read A_sub_0
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0), make_tuple(k, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // C_sub_10 += transpose(A_sub_1) * B_sub_0
...@@ -393,10 +415,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -393,10 +415,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
// read B_sub_0 // read B_sub_0
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0), make_tuple(k, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1 // C_sub_11 += transpose(A_sub_1) * B_sub_1
...@@ -408,18 +430,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -408,18 +430,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
// read B_sub_1 // read B_sub_1
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0), make_tuple(k, I1, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0, I0),
b_thread_buf); b_thread_buf);
// read A_sub_1 // read A_sub_1
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0), make_tuple(k, I1, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 // C_sub_00 += transpose(A_sub_0) * B_sub_0
...@@ -455,11 +477,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -455,11 +477,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1)); make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1)); make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
...@@ -468,20 +490,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -468,20 +490,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
FloatA, FloatA,
ABlockDesc, ABlockDesc,
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<KPerWave, 1, 1>, Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2>, Sequence<0, 1, 2, 3>,
2, 3,
1, 1, // KPack,
1>; 1>;
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB, FloatB,
BBlockDesc, BBlockDesc,
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<KPerWave, 1, 1>, Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2>, Sequence<0, 1, 2, 3>,
2, 3,
1, 1, // KPack,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
......
...@@ -110,7 +110,7 @@ template <index_t BlockSize, ...@@ -110,7 +110,7 @@ template <index_t BlockSize,
index_t KPerBlock, index_t KPerBlock,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPack,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M,
...@@ -276,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -276,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS // a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS // b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
...@@ -285,31 +285,35 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -285,31 +285,35 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
NPerBlock % (NPerWave * NRepeat) == 0, NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!"); "wrong!");
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( static_assert(KPerBlock % KPack == 0, "KPerBlock is wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc, a_k_m_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}), make_tuple(
make_unmerge_transform( make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))), make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc, b_k_n_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}), make_tuple(
make_unmerge_transform( make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))), make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline<BlockSize, BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k_m0_m1_block_desc), decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k_n0_n1_block_desc), decltype(b_k0_n0_n1_k1_block_desc),
MPerWave, MPerWave,
NPerWave, NPerWave,
KPerWave>{}; KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout(); constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize(); constexpr index_t BlkSize = CLayout.GetBlkSize();
......
...@@ -547,7 +547,7 @@ struct xdlops_info ...@@ -547,7 +547,7 @@ struct xdlops_info
static constexpr index_t GetKPerXdlops() static constexpr index_t GetKPerXdlops()
{ {
return mfma_type.k_base * (IsKReduction() ? mfma_type.num_input_blks : 1); return IsKReduction() ? mfma_type.num_input_blks : 1;
} }
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
...@@ -555,7 +555,7 @@ struct xdlops_info ...@@ -555,7 +555,7 @@ struct xdlops_info
static constexpr auto GetCType() { return CType_{}; } static constexpr auto GetCType() { return CType_{}; }
}; };
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPerWave> template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm struct XdlopsGemm
{ {
template <class base_type_ = base_type, template <class base_type_ = base_type,
...@@ -801,13 +801,13 @@ struct XdlopsGemm ...@@ -801,13 +801,13 @@ struct XdlopsGemm
is_same<base_type, ushort>::value, is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!"); "base base_type must be float, half, ushort!");
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops"); static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) { static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0)); constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0)); constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>( mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread); p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread);
......
...@@ -88,7 +88,7 @@ ...@@ -88,7 +88,7 @@
// experimental implementation // experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment