Commit 12a4ea69 authored by aska-0096's avatar aska-0096
Browse files

(3/5) Batched GEMM passes. Known perf bug: skipping LDS for A performs worse than skipping LDS for B.

parent 3ccfb0ae
...@@ -67,33 +67,33 @@ using DeviceOpInstanceKKNN = ...@@ -67,33 +67,33 @@ using DeviceOpInstanceKKNN =
ASpec, ASpec,
BSpec, BSpec,
DESpec, DESpec,
2, 1,
128, 128,
64, 64,
128, 64,
32, 64,
8, 4,
16, 16,
16, 16,
2, 1,
4, 4,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 4,
8, 4,
true, true,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 4,
8, 4,
true, true,
1, 1,
1, 1,
S<1, 32, 1, 4>, S<1, 64, 1, 2>,
8>; 8>;
using DeviceOpInstance = DeviceOpInstanceKKNN; using DeviceOpInstance = DeviceOpInstanceKKNN;
......
...@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA ...@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA
static_for<0, KPerBlock / WmmaK, 1>{}( static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A // read A
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1, a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0), make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
...@@ -346,84 +346,6 @@ struct BlockwiseGemmWMMA ...@@ -346,84 +346,6 @@ struct BlockwiseGemmWMMA
i % B_K1))>{}]; i % B_K1))>{}];
}); });
#if 0
if (get_thread_local_1d_id() == 0){
printf("repeat: m,n,k:(%02d, %02d, %02d) a_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0 / A_K1, m0, 0, 0, 0 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(1 / A_K1, m0, 0, 0, 1 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(2 / A_K1, m0, 0, 0, 2 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(3 / A_K1, m0, 0, 0, 3 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(4 / A_K1, m0, 0, 0, 4 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(5 / A_K1, m0, 0, 0, 5% A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(6 / A_K1, m0, 0, 0, 6 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(7 / A_K1, m0, 0, 0, 7 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(8 / A_K1, m0, 0, 0, 8 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(9 / A_K1, m0, 0, 0, 9% A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(10 / A_K1, m0, 0, 0, 10 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(11 / A_K1, m0, 0, 0, 11 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(12 / A_K1, m0, 0, 0, 12 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(13 / A_K1, m0, 0, 0, 13 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(14 / A_K1, m0, 0, 0, 14 % A_K1))>{}]))),
*(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(15 / A_K1, m0, 0, 0, 15% A_K1))>{}])))
);
}
// if (get_thread_local_1d_id() == 0){
// printf("repeat: m,n,k:(%02d, %02d, %02d) b_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
// m0.value, n0.value, k.value,
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(0 / B_K1, n0, 0, 0, 0 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(1 / B_K1, n0, 0, 0, 1 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(2 / B_K1, n0, 0, 0, 2 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(3 / B_K1, n0, 0, 0, 3 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(4 / B_K1, n0, 0, 0, 4 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(5 / B_K1, n0, 0, 0, 5% B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(6 / B_K1, n0, 0, 0, 6 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(7 / B_K1, n0, 0, 0, 7 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(8 / B_K1, n0, 0, 0, 8 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(9 / B_K1, n0, 0, 0, 9% B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(10 / B_K1, n0, 0, 0, 10 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(11 / B_K1, n0, 0, 0, 11 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(12 / B_K1, n0, 0, 0, 12 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(13 / B_K1, n0, 0, 0, 13 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(14 / B_K1, n0, 0, 0, 14 % B_K1))>{}]))),
// *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(15 / B_K1, n0, 0, 0, 15% B_K1))>{}])))
// );
// }
#endif
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
......
...@@ -125,6 +125,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -125,6 +125,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels // K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
...@@ -136,9 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -136,9 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
// Bug, MNK vector load check not implemented correctly static constexpr auto AEnableLds_manu = false;
static constexpr auto AEnableLds_manu = true; static constexpr auto BEnableLds_manu = false;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
...@@ -220,18 +220,21 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -220,18 +220,21 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
} }
else else
{ {
constexpr auto A_KRow = WmmaK / K1; constexpr auto A_KRow = 2;
const auto A_KWmma = K / WmmaK; constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock; const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))), make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -309,18 +312,21 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -309,18 +312,21 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
} }
else else
{ {
constexpr auto B_KRow = WmmaK / K1; constexpr auto B_KRow = 2;
const auto B_KWmma = K / WmmaK; constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock; const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_n_k, b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))), make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -752,7 +758,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -752,7 +758,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
else else
{ {
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5); arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
} }
}(); }();
...@@ -1036,7 +1042,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -1036,7 +1042,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
<< MRepeat << ", " << MRepeat << ", "
<< NRepeat << NRepeat
<< ">" << ">"
<< " NumPrefetch: " << " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
<< "LoopScheduler: " << "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment