Commit f3111877 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent 76bb51f4
...@@ -67,11 +67,11 @@ struct BlockwiseGemmWMMA ...@@ -67,11 +67,11 @@ struct BlockwiseGemmWMMA
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation // permutation
#ifdef __gfx12__ #ifdef __gfx12__
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
#else
static constexpr index_t A_KRow = 1; static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1; static constexpr index_t B_KRow = 1;
#else
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
#endif #endif
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
...@@ -563,6 +563,7 @@ struct BlockwiseGemmWMMA ...@@ -563,6 +563,7 @@ struct BlockwiseGemmWMMA
#endif #endif
protected: protected:
#ifdef __gfx12__
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}), make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{}, make_tuple(Number<A_K1>{},
...@@ -580,6 +581,35 @@ struct BlockwiseGemmWMMA ...@@ -580,6 +581,35 @@ struct BlockwiseGemmWMMA
Number<B_K1>{}, Number<B_K1>{},
Number<B_K1>{}, Number<B_K1>{},
Number<1>{})); Number<1>{}));
#else
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<KPack>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<KPack>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
#endif
// C[M, N, NumRegWMMA] // C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
...@@ -610,7 +640,7 @@ struct BlockwiseGemmWMMA ...@@ -610,7 +640,7 @@ struct BlockwiseGemmWMMA
template <> template <>
struct AThreadCopySelector<false> struct AThreadCopySelector<false>
{ {
using type = ThreadwiseTensorSliceTransfer_StaticToStatic< using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA, FloatA,
FloatA, FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_block_desc_k0_m0_m1_m2_k1),
...@@ -619,7 +649,10 @@ struct BlockwiseGemmWMMA ...@@ -619,7 +649,10 @@ struct BlockwiseGemmWMMA
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>, Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
A_K1>; A_K1,
0x76543210,
0xfedcba98,
TransposeC ? false : true>;
}; };
template <bool EnableLds> template <bool EnableLds>
...@@ -647,7 +680,7 @@ struct BlockwiseGemmWMMA ...@@ -647,7 +680,7 @@ struct BlockwiseGemmWMMA
template <> template <>
struct BThreadCopySelector<false> struct BThreadCopySelector<false>
{ {
using type = ThreadwiseTensorSliceTransfer_StaticToStatic< using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB, FloatB,
FloatB, FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_block_desc_k0_n0_n1_n2_k1),
...@@ -656,7 +689,10 @@ struct BlockwiseGemmWMMA ...@@ -656,7 +689,10 @@ struct BlockwiseGemmWMMA
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>, Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
B_K1>; B_K1,
0x76543210,
0xfedcba98,
TransposeC ? true : false>;
}; };
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_; typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
......
...@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma ...@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment