Commit c713d224 authored by aska-0096's avatar aska-0096
Browse files

Update low level abstration of blockwise gemm wmma

parent 2ec3f4c3
......@@ -37,8 +37,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault,
2, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
128, // MPerBlock
64, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
......
......@@ -34,7 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#endif
const int nrepeat = 50;
for(int i = 0; i < nrepeat; ++i)
for(int i = 0; i < nrepeat; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
}
......
......@@ -55,6 +55,7 @@ struct BlockwiseGemmWMMA
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -62,8 +63,13 @@ struct BlockwiseGemmWMMA
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static constexpr index_t WaveSize = 32;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
......@@ -71,10 +77,6 @@ struct BlockwiseGemmWMMA
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
// Read from Lds, duplicate Twice, Read from VGPR, no duplication.
static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
......@@ -105,12 +107,12 @@ struct BlockwiseGemmWMMA
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
......@@ -122,12 +124,12 @@ struct BlockwiseGemmWMMA
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
......@@ -173,9 +175,9 @@ struct BlockwiseGemmWMMA
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
......@@ -224,18 +226,6 @@ struct BlockwiseGemmWMMA
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
#if 0
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
Number<NRepeat>{},
I1,
NThreadPerSubGroup,
MAccVgprs));
#endif
}
template <typename CGridDesc_M_N>
......@@ -312,36 +302,26 @@ struct BlockwiseGemmWMMA
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPerBlock / WmmaK, 1>{}(
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
// read A
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
m0,
I0,
I0,
I0),
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
n0,
I0,
I0,
I0),
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
......@@ -350,12 +330,100 @@ struct BlockwiseGemmWMMA
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
});
#if 0
if (get_thread_local_1d_id() == 0){
printf("repeat: m,n,k:(%02d, %02d, %02d) a_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0 / A_K1, m0, 0, 0, 0 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(1 / A_K1, m0, 0, 0, 1 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(2 / A_K1, m0, 0, 0, 2 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(3 / A_K1, m0, 0, 0, 3 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(4 / A_K1, m0, 0, 0, 4 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(5 / A_K1, m0, 0, 0, 5% A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(6 / A_K1, m0, 0, 0, 6 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(7 / A_K1, m0, 0, 0, 7 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(8 / A_K1, m0, 0, 0, 8 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(9 / A_K1, m0, 0, 0, 9% A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(10 / A_K1, m0, 0, 0, 10 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(11 / A_K1, m0, 0, 0, 11 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(12 / A_K1, m0, 0, 0, 12 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(13 / A_K1, m0, 0, 0, 13 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(14 / A_K1, m0, 0, 0, 14 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(15 / A_K1, m0, 0, 0, 15% A_K1))>{}])))
);
}
// if (get_thread_local_1d_id() == 0){
// printf("repeat: m,n,k:(%02d, %02d, %02d) b_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
// m0.value, n0.value, k.value,
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(0 / B_K1, n0, 0, 0, 0 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(1 / B_K1, n0, 0, 0, 1 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(2 / B_K1, n0, 0, 0, 2 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(3 / B_K1, n0, 0, 0, 3 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(4 / B_K1, n0, 0, 0, 4 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(5 / B_K1, n0, 0, 0, 5% B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(6 / B_K1, n0, 0, 0, 6 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(7 / B_K1, n0, 0, 0, 7 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(8 / B_K1, n0, 0, 0, 8 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(9 / B_K1, n0, 0, 0, 9% B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(10 / B_K1, n0, 0, 0, 10 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(11 / B_K1, n0, 0, 0, 11 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(12 / B_K1, n0, 0, 0, 12 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(13 / B_K1, n0, 0, 0, 13 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(14 / B_K1, n0, 0, 0, 14 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(15 / B_K1, n0, 0, 0, 15% B_K1))>{}])))
// );
// }
#endif
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
......@@ -372,36 +440,26 @@ struct BlockwiseGemmWMMA
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
n0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
m0,
I0,
I0,
I0),
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
......@@ -410,10 +468,20 @@ struct BlockwiseGemmWMMA
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
......@@ -427,20 +495,39 @@ struct BlockwiseGemmWMMA
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
}
}
protected:
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{}));
// B[K0, N0, N1, N2, K1]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{}, Number<WmmaK>{}, Number<B_K1>{}, Number<B_K1>{}, Number<1>{}));
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<WmmaK>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<WmmaK>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......@@ -452,15 +539,16 @@ struct BlockwiseGemmWMMA
template <>
struct AThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <>
......@@ -472,9 +560,9 @@ struct BlockwiseGemmWMMA
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
0x76543210,
0xfedcba98,
......@@ -487,15 +575,16 @@ struct BlockwiseGemmWMMA
template <>
struct BThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
......@@ -507,9 +596,9 @@ struct BlockwiseGemmWMMA
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
0x76543210,
0xfedcba98,
......
......@@ -80,6 +80,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
......@@ -136,18 +137,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
......@@ -187,18 +191,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
......@@ -372,7 +379,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();
......
......@@ -297,15 +297,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
#if 0
constexpr auto a_block_origin_idx = generate_sequence_v2(
[]() constexpr {
return Number<0>{};
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
......@@ -404,7 +396,7 @@ struct GridwiseGemmPipeline_v1<1, true, false>
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
......
......@@ -172,10 +172,23 @@ struct GridwiseGemm_Wmma
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
......@@ -206,10 +219,23 @@ struct GridwiseGemm_Wmma
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
......@@ -229,7 +255,7 @@ struct GridwiseGemm_Wmma
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
......@@ -249,7 +275,7 @@ struct GridwiseGemm_Wmma
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
......@@ -264,23 +290,26 @@ struct GridwiseGemm_Wmma
constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Err: merge transform cause non-constexpr issue
......@@ -301,26 +330,12 @@ struct GridwiseGemm_Wmma
// Sequence<4>{}));
// Workaround, Freeze transform
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}));
}
}();
......@@ -334,44 +349,33 @@ struct GridwiseGemm_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5);
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<NRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}));
}
}();
......@@ -415,9 +419,9 @@ struct GridwiseGemm_Wmma
else
{
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4),
a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5));
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
}
};
......@@ -430,9 +434,9 @@ struct GridwiseGemm_Wmma
else
{
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5));
b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
}
};
......@@ -599,7 +603,8 @@ struct GridwiseGemm_Wmma
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
* a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
}
}();
......@@ -652,6 +657,7 @@ struct GridwiseGemm_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
......@@ -664,11 +670,12 @@ struct GridwiseGemm_Wmma
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
......@@ -676,6 +683,7 @@ struct GridwiseGemm_Wmma
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
......@@ -729,6 +737,7 @@ struct GridwiseGemm_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
......@@ -741,11 +750,12 @@ struct GridwiseGemm_Wmma
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
......@@ -753,6 +763,7 @@ struct GridwiseGemm_Wmma
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
......
......@@ -1387,7 +1387,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// src_desc error, non constexpr?
// src_desc error, non constexpr, caused by merge transform
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
......@@ -1396,8 +1396,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
SrcData v_this_row, v_theother_row;
// int type temp value due to intrinsic requirement
// TODO: This temp value will generate the scratch memory if
// IntraRowSwizzlePerm is flase
int temp = 0;
// apply element-wise operation
......@@ -1419,7 +1417,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
1,
0);
v_theother_row = type_convert_sp<SrcData>(temp);
// if (get_thread_local_1d_id() == 0){
// printf("src_offset:%d, dst_offset for this row: %d, dst_offset
// for the other row: %d \n",
// src_offset, dst_offset, dst_offset+DstScalarPerVector);}
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment