"...resnet50_tensorflow.git" did not exist on "589ac399f0777c0c9e71a8a9bdb2b019cefeb536"
Commit c713d224 authored by aska-0096

Update low-level abstraction of blockwise GEMM WMMA

parent 2ec3f4c3
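
The essence of this change: the per-thread A/B descriptors and copy origins gain an explicit KRow dimension (5-tuples become 6-tuples, and the non-LDS grid descriptors grow to 7 dimensions), so a linear index into one WmmaK slice is now split three ways instead of two. A minimal standalone sketch of the new index arithmetic, assuming WmmaK = 16, K1 = 8 and KRow = 2 (the non-LDS values used in the hunks below):

```cpp
// Illustrative only; mirrors the (i / K1 / KRow, (i / K1) % KRow, i % K1)
// decomposition introduced by this commit.
#include <cstdio>

int main()
{
    constexpr int WmmaK = 16; // K extent consumed by one WMMA instruction
    constexpr int K1    = 8;  // contiguous scalars per vector access
    constexpr int KRow  = 2;  // halves of WmmaK held by the two 16-lane rows

    // Old indexing: i -> (i / K1, i % K1). New indexing splits out the row:
    for(int i = 0; i < WmmaK; ++i)
    {
        int k0   = i / K1 / KRow;   // outer K index
        int krow = (i / K1) % KRow; // which 16-lane row holds this element
        int k1   = i % K1;          // position inside the K1 vector
        std::printf("i=%2d -> (k0=%d, krow=%d, k1=%d)\n", i, k0, krow, k1);
    }
    return 0;
}
```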
@@ -37,8 +37,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
        GemmDefault,
        2,           // Prefetch stage
        128,         // BlockSize
        128,         // MPerBlock
        64,          // NPerBlock
        64,          // KPerBlock
        8,           // K1
        16,          // MPerWmma
......
@@ -34,7 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#endif
        const int nrepeat = 50;
        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
        }
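
For context, the `nrepeat` loop above is the hot loop of a simple launch-and-average timing scheme. A self-contained sketch of the same idea using standard HIP events (the `dummy_kernel` is illustrative, not from the repository):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    const int nrepeat = 50;
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(dummy_kernel, dim3(1), dim3(64), 0, nullptr);
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float ms = 0.f;
    hipEventElapsedTime(&ms, start, stop);
    std::printf("avg kernel time: %f ms\n", ms / nrepeat); // mean over nrepeat launches
    return 0;
}
```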
......
@@ -55,6 +55,7 @@ struct BlockwiseGemmWMMA
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
+   static constexpr auto I5 = Number<5>{};
    static constexpr auto WmmaK = Number<16>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -62,8 +63,13 @@ struct BlockwiseGemmWMMA
    // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
    static constexpr index_t WaveSize = 32;

-   static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
-   static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
+   // When using LDS, each row (16 consecutive lanes) reads the whole data from the source buffer.
+   // When not using LDS, each row reads half of the data and exchanges it with the other row via
+   // permutation.
+   static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
+   static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
+   static constexpr index_t A_K1   = ABlockDesc{}.GetLength(I5);
+   static constexpr index_t B_K1   = BBlockDesc{}.GetLength(I5);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
@@ -71,10 +77,6 @@ struct BlockwiseGemmWMMA
    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

-   // Read from Lds, duplicate Twice, Read from VGPR, no duplication.
-   static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
-   static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
-
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
@@ -105,12 +107,12 @@ struct BlockwiseGemmWMMA
            const auto waveId_m = wave_idx[I0];
            const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

-           //  |KRepeat   |MRepeat|MWave    |MLane    |KPack
-           return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
+           //  |KRepeat   |MRepeat|MWave    |KRow  |MLane    |KPack
+           return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
        }
        else
        {
-           return make_tuple(0, 0, 0, 0, 0);
+           return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }
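
The extra `0` in the returned tuple is the new KRow coordinate. A rough sketch of how a thread id decomposes onto wave, row, and lane under the hard-coded WaveSize of 32, mirroring the `get_thread_local_1d_id()` arithmetic used later in this commit (names are illustrative):

```cpp
#include <cstdio>

int main()
{
    for(int tid = 0; tid < 64; tid += 8) // sample a few lanes of a 64-thread block
    {
        int wave = tid / 32;        // which wave within the block
        int krow = (tid % 32) / 16; // which 16-lane row inside the wave
        int lane = tid % 16;        // lane inside the row (the MLane/NLane index)
        std::printf("tid=%2d -> wave=%d, krow=%d, lane=%2d\n", tid, wave, krow, lane);
    }
    return 0;
}
```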
@@ -122,12 +124,12 @@ struct BlockwiseGemmWMMA
            const auto waveId_n = wave_idx[I1];
            const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

-           //  |KRepeat   |NRepeat|Nwave    |NLane    |KPack
-           return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
+           //  |KRepeat   |NRepeat|Nwave    |KRow  |NLane    |KPack
+           return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
        }
        else
        {
-           return make_tuple(0, 0, 0, 0, 0);
+           return make_tuple(0, 0, 0, 0, 0, 0);
        }
    }
@@ -173,9 +175,9 @@ struct BlockwiseGemmWMMA
            Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
    }

-   using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
+   using Tuple6 = decltype(CalculateAThreadOriginDataIndex());

-   __host__ __device__ BlockwiseGemmWMMA(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
-                                         Tuple5 b_origin = CalculateBThreadOriginDataIndex())
+   __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
+                                         Tuple6 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
@@ -224,18 +226,6 @@ struct BlockwiseGemmWMMA
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       AccStride));
-#if 0
-       return make_naive_tensor_descriptor_packed(
-           //        |MRepeat            |MWave |MSubGroup |NRepeat           |NWave
-           //        |NThreadPerSubGroup |MAccVgprs
-           make_tuple(Number<MRepeat>{},
-                      I1,
-                      MSubGroup,
-                      Number<NRepeat>{},
-                      I1,
-                      NThreadPerSubGroup,
-                      MAccVgprs));
-#endif
    }

    template <typename CGridDesc_M_N>
@@ -312,36 +302,26 @@ struct BlockwiseGemmWMMA
        // basic intrinsic to determine loopover direction
        if constexpr(MRepeat < NRepeat)
        {
-           static_for<0, MRepeat, 1>{}([&](auto m0) {
-               static_for<0, NRepeat, 1>{}([&](auto n0) {
-                   static_for<0, KPerBlock / WmmaK, 1>{}(
-                       [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
-                           // read A
-                           a_thread_copy_.Run(
-                               a_block_desc_k0_m0_m1_m2_k1,
-                               make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
-                                          m0,
-                                          I0,
-                                          I0,
-                                          I0),
-                               a_block_buf,
-                               a_thread_desc_,
-                               make_tuple(I0, m0, I0, I0, I0),
-                               a_thread_buf);
-
-                           // read B
-                           b_thread_copy_.Run(
-                               b_block_desc_k0_n0_n1_n2_k1,
-                               make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
-                                          n0,
-                                          I0,
-                                          I0,
-                                          I0),
-                               b_block_buf,
-                               b_thread_desc_,
-                               make_tuple(I0, n0, I0, I0, I0),
-                               b_thread_buf);
+           static_for<0, KPerBlock / WmmaK, 1>{}(
+               [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
+                   static_for<0, MRepeat, 1>{}([&](auto m0) {
+                       // read A
+                       a_thread_copy_.Run(
+                           a_block_desc_k0_m0_m1_m2_k1,
+                           make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
+                           a_block_buf,
+                           a_thread_desc_,
+                           make_tuple(I0, m0, I0, I0, I0, I0),
+                           a_thread_buf);
+
+                       static_for<0, NRepeat, 1>{}([&](auto n0) {
+                           // read B
+                           b_thread_copy_.Run(
+                               b_block_desc_k0_n0_n1_n2_k1,
+                               make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, n0, I0, I0, I0, I0),
+                               b_thread_buf);

                            vector_type<FloatA, WmmaK> a_thread_vec;
@@ -350,12 +330,100 @@ struct BlockwiseGemmWMMA
                            static_for<0, WmmaK, 1>{}([&](auto i) {
                                a_thread_vec.template AsType<FloatA>()(i) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                       make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
+                                       make_tuple(i / A_K1 / A_KRow,
+                                                  m0,
+                                                  0,
+                                                  (i / A_K1) % A_KRow,
+                                                  0,
+                                                  i % A_K1))>{}];
                                b_thread_vec.template AsType<FloatB>()(i) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                       make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                                       make_tuple(i / B_K1 / B_KRow,
+                                                  n0,
+                                                  0,
+                                                  (i / B_K1) % B_KRow,
+                                                  0,
+                                                  i % B_K1))>{}];
                            });
+#if 0
+                           if(get_thread_local_1d_id() == 0)
+                           {
+                               printf("repeat: m,n,k:(%02d, %02d, %02d) a_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
+                                   m0.value, n0.value, k.value,
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(0 / A_K1, m0, 0, 0, 0 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(1 / A_K1, m0, 0, 0, 1 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(2 / A_K1, m0, 0, 0, 2 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(3 / A_K1, m0, 0, 0, 3 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(4 / A_K1, m0, 0, 0, 4 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(5 / A_K1, m0, 0, 0, 5 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(6 / A_K1, m0, 0, 0, 6 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(7 / A_K1, m0, 0, 0, 7 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(8 / A_K1, m0, 0, 0, 8 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(9 / A_K1, m0, 0, 0, 9 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(10 / A_K1, m0, 0, 0, 10 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(11 / A_K1, m0, 0, 0, 11 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(12 / A_K1, m0, 0, 0, 12 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(13 / A_K1, m0, 0, 0, 13 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(14 / A_K1, m0, 0, 0, 14 % A_K1))>{}]))),
+                                   *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                       make_tuple(15 / A_K1, m0, 0, 0, 15 % A_K1))>{}])))
+                               );
+                           }
+                           // if (get_thread_local_1d_id() == 0){
+                           //     printf("repeat: m,n,k:(%02d, %02d, %02d) b_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
+                           //     m0.value, n0.value, k.value,
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(0 / B_K1, n0, 0, 0, 0 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(1 / B_K1, n0, 0, 0, 1 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(2 / B_K1, n0, 0, 0, 2 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(3 / B_K1, n0, 0, 0, 3 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(4 / B_K1, n0, 0, 0, 4 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(5 / B_K1, n0, 0, 0, 5 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(6 / B_K1, n0, 0, 0, 6 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(7 / B_K1, n0, 0, 0, 7 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(8 / B_K1, n0, 0, 0, 8 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(9 / B_K1, n0, 0, 0, 9 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(10 / B_K1, n0, 0, 0, 10 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(11 / B_K1, n0, 0, 0, 11 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(12 / B_K1, n0, 0, 0, 12 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(13 / B_K1, n0, 0, 0, 13 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(14 / B_K1, n0, 0, 0, 14 % B_K1))>{}]))),
+                           //     *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                           //         make_tuple(15 / B_K1, n0, 0, 0, 15 % B_K1))>{}])))
+                           //     );
+                           // }
+#endif

                            using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                            using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
@@ -372,36 +440,26 @@ struct BlockwiseGemmWMMA
        }
        else
        {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, KPerBlock / WmmaK, 1>{}(
                        [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
                            // read B
                            b_thread_copy_.Run(
                                b_block_desc_k0_n0_n1_n2_k1,
-                               make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
-                                          n0,
-                                          I0,
-                                          I0,
-                                          I0),
+                               make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
-                               make_tuple(I0, n0, I0, I0, I0),
+                               make_tuple(I0, n0, I0, I0, I0, I0),
                                b_thread_buf);

                            // read A
                            a_thread_copy_.Run(
                                a_block_desc_k0_m0_m1_m2_k1,
-                               make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
-                                          m0,
-                                          I0,
-                                          I0,
-                                          I0),
+                               make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                                a_block_buf,
                                a_thread_desc_,
-                               make_tuple(I0, m0, I0, I0, I0),
+                               make_tuple(I0, m0, I0, I0, I0, I0),
                                a_thread_buf);

                            vector_type<FloatA, WmmaK> a_thread_vec;
@@ -410,10 +468,20 @@ struct BlockwiseGemmWMMA
                            static_for<0, WmmaK, 1>{}([&](auto i) {
                                b_thread_vec.template AsType<FloatB>()(i) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                       make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                                       make_tuple(i / B_K1 / B_KRow,
+                                                  n0,
+                                                  0,
+                                                  (i / B_K1) % B_KRow,
+                                                  0,
+                                                  i % B_K1))>{}];
                                a_thread_vec.template AsType<FloatA>()(i) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                       make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
+                                       make_tuple(i / A_K1 / A_KRow,
+                                                  m0,
+                                                  0,
+                                                  (i / A_K1) % A_KRow,
+                                                  0,
+                                                  i % A_K1))>{}];
                            });

                            using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
@@ -427,20 +495,39 @@ struct BlockwiseGemmWMMA
                                b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                });
            });
        }
    }

    protected:
-   static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
-       make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}),
-       make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{}));
-
-   // B[K0, N0, N1, N2, K1]
-   static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
-       make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}),
-       make_tuple(Number<B_K1>{}, Number<WmmaK>{}, Number<B_K1>{}, Number<B_K1>{}, Number<1>{}));
+   static constexpr auto a_thread_desc_ =
+       make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
+                                               Number<MRepeat>{},
+                                               I1,
+                                               Number<A_KRow>{},
+                                               I1,
+                                               Number<A_K1>{}),
+                                    make_tuple(Number<A_K1 * A_KRow>{},
+                                               Number<WmmaK>{},
+                                               Number<A_K1 * A_KRow>{},
+                                               Number<A_K1>{},
+                                               Number<A_K1>{},
+                                               Number<1>{}));
+
+   static constexpr auto b_thread_desc_ =
+       make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
+                                               Number<NRepeat>{},
+                                               I1,
+                                               Number<B_KRow>{},
+                                               I1,
+                                               Number<B_K1>{}),
+                                    make_tuple(Number<B_K1 * B_KRow>{},
+                                               Number<WmmaK>{},
+                                               Number<B_K1 * B_KRow>{},
+                                               Number<B_K1>{},
+                                               Number<B_K1>{},
+                                               Number<1>{}));

    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
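
A quick worked check (not part of the commit) that the new `a_thread_desc_` strides keep one full WmmaK slice contiguous per repeat; values assume WmmaK = 16, A_K1 = 8, A_KRow = 2 and MRepeat = 2:

```cpp
#include <cstdio>

int main()
{
    constexpr int WmmaK = 16, A_K1 = 8, A_KRow = 2, MRepeat = 2;
    // strides from the descriptor above: {A_K1*A_KRow, WmmaK, A_K1*A_KRow, A_K1, A_K1, 1}
    constexpr int s_k0 = A_K1 * A_KRow, s_m0 = WmmaK, s_krow = A_K1;

    for(int m0 = 0; m0 < MRepeat; ++m0)
        for(int i = 0; i < WmmaK; ++i) // i runs over one WmmaK slice
        {
            int off = (i / A_K1 / A_KRow) * s_k0 + m0 * s_m0
                      + ((i / A_K1) % A_KRow) * s_krow + (i % A_K1);
            std::printf("m0=%d i=%2d -> offset %2d\n", m0, i, off); // 0..15, then 16..31
        }
    return 0;
}
```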
@@ -452,15 +539,16 @@ struct BlockwiseGemmWMMA
    template <>
    struct AThreadCopySelector<true>
    {
-       using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                     FloatA,
-                                                     decltype(a_block_desc_k0_m0_m1_m2_k1),
-                                                     decltype(a_thread_desc_),
-                                                     Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
-                                                     Sequence<0, 1, 2, 3, 4>,
-                                                     4,
-                                                     A_K1,
-                                                     A_K1>;
+       using type =
+           ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                            FloatA,
+                                            decltype(a_block_desc_k0_m0_m1_m2_k1),
+                                            decltype(a_thread_desc_),
+                                            Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
+                                            Sequence<0, 1, 2, 3, 4, 5>,
+                                            5,
+                                            A_K1,
+                                            A_K1>;
    };

    template <>
@@ -472,9 +560,9 @@ struct BlockwiseGemmWMMA
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         tensor_operation::element_wise::PassThrough,
-                                        Sequence<1, 1, 1, 1, A_K1>,
-                                        Sequence<0, 1, 2, 3, 4>,
-                                        4,
+                                        Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+                                        Sequence<0, 1, 2, 3, 4, 5>,
+                                        5,
                                         A_K1,
                                         0x76543210,
                                         0xfedcba98,
@@ -487,15 +575,16 @@ struct BlockwiseGemmWMMA
    template <>
    struct BThreadCopySelector<true>
    {
-       using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                     FloatB,
-                                                     decltype(b_block_desc_k0_n0_n1_n2_k1),
-                                                     decltype(b_thread_desc_),
-                                                     Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
-                                                     Sequence<0, 1, 2, 3, 4>,
-                                                     4,
-                                                     B_K1,
-                                                     B_K1>;
+       using type =
+           ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                            FloatB,
+                                            decltype(b_block_desc_k0_n0_n1_n2_k1),
+                                            decltype(b_thread_desc_),
+                                            Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+                                            Sequence<0, 1, 2, 3, 4, 5>,
+                                            5,
+                                            B_K1,
+                                            B_K1>;
    };

    template <>
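
One invariant worth noting about the new slice lengths `Sequence<WmmaK / K1 / KRow, 1, 1, KRow, 1, K1>`: their product is still exactly WmmaK, so each `Run()` still moves one full WMMA K-slice. A throwaway check (illustrative only):

```cpp
#include <cassert>
#include <initializer_list>

int main()
{
    constexpr int WmmaK = 16;
    // the slice-length product must stay WmmaK for any valid K1/KRow split
    for(int K1 : {2, 4, 8})
        for(int KRow : {1, 2})
            assert((WmmaK / K1 / KRow) * KRow * K1 == WmmaK);
    return 0;
}
```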
@@ -507,9 +596,9 @@ struct BlockwiseGemmWMMA
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         tensor_operation::element_wise::PassThrough,
-                                        Sequence<1, 1, 1, 1, B_K1>,
-                                        Sequence<0, 1, 2, 3, 4>,
-                                        4,
+                                        Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+                                        Sequence<0, 1, 2, 3, 4, 5>,
+                                        5,
                                         B_K1,
                                         0x76543210,
                                         0xfedcba98,
......
@@ -80,6 +80,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
+   static constexpr auto I6 = Number<6>{};

    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};
@@ -136,18 +137,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
        }
        else
        {
-           constexpr auto A_KRow = WmmaK / K1;
+           constexpr auto A_KRow      = 2;
+           constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma         = K / WmmaK;

            const auto M0 = M / MPerBlock;
+           //  0   1      0         1                2        3             4        5          6
+           //  M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
-               make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
+               make_tuple(make_unmerge_transform(make_tuple(
+                              A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
-               make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+               make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }
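
A worked example (not from the commit) of the four-way unmerge above, with assumed values K = 64, WmmaK = 16, K1 = 8, hence A_KRow = 2, A_K0PerWmma = 16/2/8 = 1 and A_KWmma = 4; reassembling the split coordinates must reproduce the linear K index:

```cpp
#include <cassert>

int main()
{
    constexpr int K1 = 8, A_KRow = 2, A_K0PerWmma = 1, A_KWmma = 4;
    for(int k = 0; k < A_KWmma * A_K0PerWmma * A_KRow * K1; ++k)
    {
        int k1    = k % K1;
        int krow  = (k / K1) % A_KRow;
        int k0pw  = (k / K1 / A_KRow) % A_K0PerWmma;
        int kwmma = k / K1 / A_KRow / A_K0PerWmma;
        // merging the unmerged indices must give back the original k
        assert(((kwmma * A_K0PerWmma + k0pw) * A_KRow + krow) * K1 + k1 == k);
    }
    return 0;
}
```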
@@ -187,18 +191,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
        }
        else
        {
-           constexpr auto B_KRow = WmmaK / K1;
+           constexpr auto B_KRow      = 2;
+           constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma         = K / WmmaK;

            const auto N0 = N / NPerBlock;
+           //  0   1      0         1                2        3             4        5          6
+           //  N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
            return transform_tensor_descriptor(
                b_grid_desc_n_k,
-               make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
+               make_tuple(make_unmerge_transform(make_tuple(
+                              B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
-               make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+               make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }
@@ -372,7 +379,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
            else
            {
                return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
-                      arg.a_grid_desc_.GetLength(I5);
+                      arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
            }
        }();
......
@@ -297,15 +297,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
-#if 0
-       constexpr auto a_block_origin_idx = generate_sequence_v2(
-           []() constexpr {
-               return Number<0>{};
-           },
-           Number<a_block_desc.GetLengths().GetSize()>{});
-#endif
-       constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+       constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        auto a_block_buf_switch           = a_block_buf;

        // preload data into LDS
@@ -404,7 +396,7 @@ struct GridwiseGemmPipeline_v1<1, true, false>
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
-       constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+       constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        auto b_block_buf_switch           = b_block_buf;

        // preload data into LDS
......
@@ -172,10 +172,23 @@ struct GridwiseGemm_Wmma
        else
        {
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-           // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+           constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+           // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
-               make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
-               make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
+               make_tuple(Number<KWmmaPerblock>{},
+                          Number<MRepeat>{},
+                          I1,
+                          Number<K0PerWmma>{},
+                          I1,
+                          I1,
+                          K1),
+               make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
+                          Number<K0PerWmma>{} * K1,
+                          Number<K0PerWmma>{} * K1,
+                          K1,
+                          K1,
+                          K1,
+                          I1));
        }
    }();
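
For a feel of the sizes involved, a back-of-the-envelope count of per-thread A elements implied by the new 7-D descriptor, under assumed values KPerBlock = 64, WmmaK = 16, K1 = 8 and MRepeat = 2:

```cpp
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 64, WmmaK = 16, K1 = 8, MRepeat = 2;
    constexpr int KWmmaPerblock = KPerBlock / WmmaK; // 4
    constexpr int K0PerWmma     = WmmaK / 2 / K1;    // 1 (the "/ 2" is the KRow split)
    // lengths {KWmmaPerblock, MRepeat, 1, K0PerWmma, 1, 1, K1}, packed together
    constexpr int elems = KWmmaPerblock * MRepeat * K0PerWmma * K1; // 64 scalars per thread
    std::printf("A elements per thread: %d\n", elems);
    return 0;
}
```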
@@ -206,10 +219,23 @@ struct GridwiseGemm_Wmma
        else
        {
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-           // KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
+           constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+           // KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
-               make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
-               make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
+               make_tuple(Number<KWmmaPerblock>{},
+                          Number<NRepeat>{},
+                          I1,
+                          Number<K0PerWmma>{},
+                          I1,
+                          I1,
+                          K1),
+               make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
+                          Number<K0PerWmma>{} * K1,
+                          Number<K0PerWmma>{} * K1,
+                          K1,
+                          K1,
+                          K1,
+                          I1));
        }
    }();
@@ -229,7 +255,7 @@ struct GridwiseGemm_Wmma
        {
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

-           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
        }
    }();
@@ -249,7 +275,7 @@ struct GridwiseGemm_Wmma
        {
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

-           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
        }
    }();
@@ -264,23 +290,26 @@ struct GridwiseGemm_Wmma
        constexpr auto a_wave_desc = [&]() {
            if constexpr(AEnableLds)
            {
-               // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
+               // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
                constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
                constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+               constexpr auto A_KRow = I1;
                return transform_tensor_descriptor(
                    ABlockDesc_{},
-                   make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                   make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
                               make_unmerge_transform(make_tuple(
                                   Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                               make_pass_through_transform(Number<A_K1>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                   make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                   make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
            }
            else
            {
-               // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+               // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
                constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
-               constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
+               constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
+               constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
+               constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);

                // Err: merge transform causes a non-constexpr issue
@@ -301,26 +330,12 @@ struct GridwiseGemm_Wmma
                //                Sequence<4>{}));

                // Workaround: freeze transform
-               return transform_tensor_descriptor(
-                   ABlockDesc_{},
-                   make_tuple(make_freeze_transform(I0),
-                              make_pass_through_transform(Number<KWmma>{}),
-                              make_pass_through_transform(Number<MRepeat>{}),
-                              make_pass_through_transform(I1),
-                              make_pass_through_transform(I1),
-                              make_pass_through_transform(Number<A_K1>{})),
-                   make_tuple(Sequence<3>{},
-                              Sequence<0>{},
-                              Sequence<1>{},
-                              Sequence<2>{},
-                              Sequence<4>{},
-                              Sequence<5>{}),
-                   make_tuple(Sequence<>{},
-                              Sequence<0>{},
-                              Sequence<1>{},
-                              Sequence<2>{},
-                              Sequence<3>{},
-                              Sequence<4>{}));
+               return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                     Number<MRepeat>{},
+                                                                     I1,
+                                                                     Number<A_KRow>{},
+                                                                     I1,
+                                                                     Number<A_K1>{}));
            }
        }();
@@ -334,44 +349,33 @@ struct GridwiseGemm_Wmma
            if constexpr(BEnableLds)
            {
                // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
                constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
                constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+               constexpr auto B_KRow = I1;
                return transform_tensor_descriptor(
                    BBlockDesc_{},
-                   make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                   make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
                               make_unmerge_transform(make_tuple(
                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                               make_pass_through_transform(Number<B_K1>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                   make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                   make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
            }
            else
            {
-               // KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
+               // KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
                constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
-               constexpr auto B_K1  = BBlockDesc_{}.GetLength(I5);
+               constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
+               constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
+               constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);

                // Workaround: freeze transform
-               return transform_tensor_descriptor(
-                   BBlockDesc_{},
-                   make_tuple(make_freeze_transform(I0),
-                              make_pass_through_transform(Number<KWmma>{}),
-                              make_pass_through_transform(Number<NRepeat>{}),
-                              make_pass_through_transform(I1),
-                              make_pass_through_transform(I1),
-                              make_pass_through_transform(Number<B_K1>{})),
-                   make_tuple(Sequence<3>{},
-                              Sequence<0>{},
-                              Sequence<1>{},
-                              Sequence<2>{},
-                              Sequence<4>{},
-                              Sequence<5>{}),
-                   make_tuple(Sequence<>{},
-                              Sequence<0>{},
-                              Sequence<1>{},
-                              Sequence<2>{},
-                              Sequence<3>{},
-                              Sequence<4>{}));
+               return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                     Number<NRepeat>{},
+                                                                     I1,
+                                                                     Number<B_KRow>{},
+                                                                     I1,
+                                                                     Number<B_K1>{}));
            }
        }();
@@ -415,9 +419,9 @@ struct GridwiseGemm_Wmma
        else
        {
            return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
-                                 a_grid_desc.GetLength(I4),
+                                 a_grid_desc.GetLength(I5),
                              a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
-                                 a_grid_desc.GetLength(I5));
+                                 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
        }
    };
@@ -430,9 +434,9 @@ struct GridwiseGemm_Wmma
        else
        {
            return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
-                                 b_grid_desc.GetLength(I4),
+                                 b_grid_desc.GetLength(I5),
                              b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
-                                 b_grid_desc.GetLength(I5));
+                                 b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
        }
    };
@@ -599,7 +603,8 @@ struct GridwiseGemm_Wmma
                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
            }
            else{
-               return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
+               return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
+                      * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
            }
        }();
@@ -652,6 +657,7 @@ struct GridwiseGemm_Wmma
            // Thread-wise copy
            // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+           constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
            auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                a_block_desc.GetElementSpaceSize());
@@ -664,11 +670,12 @@ struct GridwiseGemm_Wmma
                Sequence<Number<KWmmaPerBlock>{},
                         Number<MRepeat>{},
                         I1,
+                        Number<K0PerWmma>{},
                         I1,
                         I1,
                         Number<K1Value>{}>,
-               Sequence<0, 1, 2, 3, 4, 5>,
-               5,
+               Sequence<0, 1, 2, 3, 4, 5, 6>,
+               6,
                ABlockTransferSrcScalarPerVector,
                AThreadTransferSrcResetCoordinateAfterRun,
                true>(
@@ -676,6 +683,7 @@ struct GridwiseGemm_Wmma
                make_multi_index(0,
                                 m_block_data_idx_on_grid / (MWaves * MPerWmma),
                                 get_thread_local_1d_id() / 32,
+                                0,
                                 (get_thread_local_1d_id() % 32) / 16,
                                 get_thread_local_1d_id() % 16,
                                 0));
@@ -729,6 +737,7 @@ struct GridwiseGemm_Wmma
            // Thread-wise copy
            // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+           constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
            auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                b_block_desc.GetElementSpaceSize());
@@ -741,11 +750,12 @@ struct GridwiseGemm_Wmma
                Sequence<Number<KWmmaPerBlock>{},
                         Number<NRepeat>{},
                         I1,
+                        Number<K0PerWmma>{},
                         I1,
                         I1,
                         Number<K1Value>{}>,
-               Sequence<0, 1, 2, 3, 4, 5>,
-               5,
+               Sequence<0, 1, 2, 3, 4, 5, 6>,
+               6,
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(
@@ -753,6 +763,7 @@ struct GridwiseGemm_Wmma
                make_multi_index(0,
                                 n_block_data_idx_on_grid / (NWaves * NPerWmma),
                                 get_thread_local_1d_id() / 32,
+                                0,
                                 (get_thread_local_1d_id() % 32) / 16,
                                 get_thread_local_1d_id() % 16,
                                 0));
......
@@ -1387,7 +1387,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
        // copy data from src_buf into dst_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-           // src_desc error, non constexpr?
+           // src_desc error, non constexpr, caused by merge transform
            constexpr index_t src_offset = src_desc.CalculateOffset(
                src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
@@ -1396,8 +1396,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
            SrcData v_this_row, v_theother_row;
            // int type temp value due to intrinsic requirement
-           // TODO: This temp value will generate the scratch memory if
-           // IntraRowSwizzlePerm is false
            int temp = 0;

            // apply element-wise operation
@@ -1419,7 +1417,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                1,
                0);
            v_theother_row = type_convert_sp<SrcData>(temp);
+           // if (get_thread_local_1d_id() == 0){
+           //     printf("src_offset:%d, dst_offset for this row: %d, dst_offset
+           //     for the other row: %d \n",
+           //     src_offset, dst_offset, dst_offset + DstScalarPerVector);}

            if(get_thread_local_1d_id() % 32 < 16)
            {
                // apply type convert
......
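
Finally, the inter-row transfer above realizes the permutation mentioned at the top of this diff: in the non-LDS path each 16-lane row of a 32-wide wave holds half of the WmmaK data and fetches the other half from the opposite row. A pure-CPU emulation of that pairing (illustrative only; the real code uses a lane-swizzle intrinsic, apparently parameterized by the `0x76543210` / `0xfedcba98` selectors):

```cpp
#include <cstdio>

int main()
{
    int wave[32];
    for(int lane = 0; lane < 32; ++lane)
        wave[lane] = lane; // pretend each lane holds one value

    for(int lane = 0; lane < 32; ++lane)
    {
        // opposite 16-lane row, same column: lanes 0..15 pair with 16..31
        int partner = (lane < 16) ? lane + 16 : lane - 16;
        std::printf("lane %2d gets other-row value %2d\n", lane, wave[partner]);
    }
    return 0;
}
```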