Commit f3111877 authored by Jing Zhang
Browse files

fixed

parent 76bb51f4
......@@ -67,11 +67,11 @@ struct BlockwiseGemmWMMA
// When LDS is not used, each row reads half of the whole data from the source buffer, and the
// rows exchange their data via permutation.
#ifdef __gfx12__
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
#else
static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1;
#else
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
#endif
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
......@@ -563,6 +563,7 @@ struct BlockwiseGemmWMMA
#endif
protected:
#ifdef __gfx12__
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
......@@ -580,6 +581,35 @@ struct BlockwiseGemmWMMA
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
#else
// Thread-level tensor descriptor for the A operand, used when __gfx12__ is NOT defined
// (this is the #else branch of the `#ifdef __gfx12__` guard).
//
// Built from two 6-element tuples. Following the usual
// make_naive_tensor_descriptor(lengths, strides) convention (TODO confirm against the
// helper's declaration), they read as:
//   lengths: (KPack / A_K1 / A_KRow, MRepeat, I1, A_KRow, I1, A_K1)
//   strides: (A_K1 * A_KRow,         KPack,   A_K1 * A_KRow, A_K1, A_K1, 1)
//
// Unlike the gfx12 variant, dimension 3 has length A_KRow rather than I1, so the row index
// is an explicit coordinate of this descriptor.
// NOTE(review): with these strides each MRepeat index addresses a run of KPack elements,
// which presumes KPack is divisible by A_K1 * A_KRow — confirm this is enforced elsewhere.
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<KPack>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
// Thread-level tensor descriptor for the B operand, used when __gfx12__ is NOT defined
// (this is the #else branch of the `#ifdef __gfx12__` guard).
//
// Built from two 6-element tuples. Following the usual
// make_naive_tensor_descriptor(lengths, strides) convention (TODO confirm against the
// helper's declaration), they read as:
//   lengths: (KPack / B_K1 / B_KRow, NRepeat, I1, B_KRow, I1, B_K1)
//   strides: (B_K1 * B_KRow,         KPack,   B_K1 * B_KRow, B_K1, B_K1, 1)
//
// Unlike the gfx12 variant, dimension 3 has length B_KRow rather than I1, so the row index
// is an explicit coordinate of this descriptor.
// NOTE(review): with these strides each NRepeat index addresses a run of KPack elements,
// which presumes KPack is divisible by B_K1 * B_KRow — confirm this is enforced elsewhere.
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<KPack>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
#endif
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......@@ -610,7 +640,7 @@ struct BlockwiseGemmWMMA
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
......@@ -619,7 +649,10 @@ struct BlockwiseGemmWMMA
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1>;
A_K1,
0x76543210,
0xfedcba98,
TransposeC ? false : true>;
};
template <bool EnableLds>
......@@ -647,7 +680,7 @@ struct BlockwiseGemmWMMA
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
......@@ -656,7 +689,10 @@ struct BlockwiseGemmWMMA
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1>;
B_K1,
0x76543210,
0xfedcba98,
TransposeC ? true : false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
......
......@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment