Commit f221c68e authored by Jing Zhang

merge navi3_ref

parent 37560a6d
@@ -66,8 +66,8 @@ struct BlockwiseGemmWMMA
     // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
     // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
     // permutation
-    static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
-    static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
+    static constexpr index_t A_KRow = 2;
+    static constexpr index_t B_KRow = 2;
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
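Note: hardwiring A_KRow / B_KRow to 2 makes the split-row layout unconditional: the KPack elements a thread consumes along K are spread over the two 16-lane rows of a wave32, so each thread touches only KPack / 2 of them whether or not LDS is used. A minimal host-side sketch of that index split, with illustrative sizes (the KPack / K1 values and the row-to-K mapping below are assumptions for demonstration, not taken from the source):

    #include <cstdio>

    int main()
    {
        constexpr int KPack = 8, KRow = 2, K1 = 4; // illustrative; K1 = read vector width
        for(int row = 0; row < KRow; ++row)        // the two 16-lane rows of a wave32
            for(int i = 0; i < KPack / KRow; ++i)  // each row covers half of KPack
            {
                int k = row * (KPack / KRow) + i;  // assumed global K element for this slot
                std::printf("row %d slot %d -> k0 = %d, k1 = %d (k = %d)\n",
                            row, i, i / K1, i % K1, k);
            }
    }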
@@ -213,19 +213,20 @@ struct BlockwiseGemmWMMA
         constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
-        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
-        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
-        return make_naive_tensor_descriptor(
+        constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
+        constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
+        return make_naive_tensor_descriptor_packed(
             // |MRepeat |MWave |MSubGroup |NRepeat |NWave
             // |NThreadPerSubGroup |MAccVgprs
-            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
-            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
-                       Number<NRepeat>{} * MAccVgprs * AccStride,
-                       Number<NRepeat>{} * MAccVgprs * AccStride,
-                       MAccVgprs * AccStride,
-                       MAccVgprs * AccStride,
-                       MAccVgprs * AccStride,
-                       AccStride));
+            make_tuple(Number<MRepeat>{},
+                       I1,
+                       MSubGroup,
+                       Number<NRepeat>{},
+                       I1,
+                       NThreadPerSubGroup,
+                       MAccVgprs));
     }
 
     template <typename CGridDesc_M_N>
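Note: the accumulator descriptor switches from hand-written strides to make_naive_tensor_descriptor_packed, whose strides are the running products of the lengths to the right of each dimension. A quick standalone check of that packed-stride rule (the length values are illustrative assumptions, not from the source):

    #include <array>
    #include <cstdio>

    int main()
    {
        // 7-D lengths: MRepeat, MWave, MSubGroup, NRepeat, NWave, NThreadPerSubGroup, MAccVgprs
        std::array<int, 7> len{2, 1, 1, 2, 1, 1, 8};
        std::array<int, 7> stride{};
        int s = 1;
        for(int d = 6; d >= 0; --d) // packed: stride[d] = product of len[d+1..6]
        {
            stride[d] = s;
            s *= len[d];
        }
        for(int d = 0; d < 7; ++d)
            std::printf("dim %d: len %d stride %d\n", d, len[d], stride[d]);
    }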
@@ -324,30 +325,25 @@ struct BlockwiseGemmWMMA
                         make_tuple(I0, n0, I0, I0, I0, I0),
                         b_thread_buf);
 
-                    vector_type<FloatA, KPack> a_thread_vec;
-                    vector_type<FloatB, KPack> b_thread_vec;
+                    vector_type<FloatA, KPack / A_KRow> a_thread_vec;
+                    vector_type<FloatB, KPack / B_KRow> b_thread_vec;
 
-                    static_for<0, KPack, 1>{}([&](auto i) {
+                    static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
                         a_thread_vec.template AsType<FloatA>()(i) =
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                make_tuple(i / A_K1 / A_KRow,
-                                           m0,
-                                           0,
-                                           (i / A_K1) % A_KRow,
-                                           0,
-                                           i % A_K1))>{}];
+                                make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
+                    });
+
+                    static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
                         b_thread_vec.template AsType<FloatB>()(i) =
                             b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                make_tuple(i / B_K1 / B_KRow,
-                                           n0,
-                                           0,
-                                           (i / B_K1) % B_KRow,
-                                           0,
-                                           i % B_K1))>{}];
+                                make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
                     });
 
-                    using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
-                    using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
+                    using wmma_input_type_a =
+                        typename vector_type<FloatA, WmmaK / A_KRow>::type;
+                    using wmma_input_type_b =
+                        typename vector_type<FloatB, WmmaK / B_KRow>::type;
 
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
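Note: the single KPack loop that filled both operand vectors is split into two loops of KPack / KRow iterations, one per operand, and the row coordinate ((i / A_K1) % A_KRow) drops out of the offset. A plain-array sketch of the resulting copy pattern (buffer contents and sizes are illustrative only):

    #include <array>

    int main()
    {
        constexpr int KPack = 16, A_KRow = 2, A_K1 = 4;
        std::array<float, KPack / A_KRow> a_thread_buf{}, a_thread_vec{};
        for(int i = 0; i < KPack / A_KRow; ++i) // was: i < KPack, with an extra row term
        {
            // mirrors make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1) at m0 == 0
            int offset      = (i / A_K1) * A_K1 + i % A_K1;
            a_thread_vec[i] = a_thread_buf[offset];
        }
        return 0;
    }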
@@ -383,30 +379,25 @@ struct BlockwiseGemmWMMA
                         make_tuple(I0, m0, I0, I0, I0, I0),
                         a_thread_buf);
 
-                    vector_type<FloatA, KPack> a_thread_vec;
-                    vector_type<FloatB, KPack> b_thread_vec;
+                    vector_type<FloatA, KPack / A_KRow> a_thread_vec;
+                    vector_type<FloatB, KPack / B_KRow> b_thread_vec;
 
-                    static_for<0, KPack, 1>{}([&](auto i) {
-                        b_thread_vec.template AsType<FloatB>()(i) =
-                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                make_tuple(i / B_K1 / B_KRow,
-                                           n0,
-                                           0,
-                                           (i / B_K1) % B_KRow,
-                                           0,
-                                           i % B_K1))>{}];
+                    static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
                         a_thread_vec.template AsType<FloatA>()(i) =
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                make_tuple(i / A_K1 / A_KRow,
-                                           m0,
-                                           0,
-                                           (i / A_K1) % A_KRow,
-                                           0,
-                                           i % A_K1))>{}];
+                                make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
+                    });
+
+                    static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
+                        b_thread_vec.template AsType<FloatB>()(i) =
+                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
                     });
 
-                    using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
-                    using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
+                    using wmma_input_type_a =
+                        typename vector_type<FloatA, WmmaK / A_KRow>::type;
+                    using wmma_input_type_b =
+                        typename vector_type<FloatB, WmmaK / B_KRow>::type;
 
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -422,33 +413,23 @@ struct BlockwiseGemmWMMA
     }
 
     protected:
-    static constexpr auto a_thread_desc_ =
-        make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
-                                                Number<MRepeat>{},
-                                                I1,
-                                                Number<A_KRow>{},
-                                                I1,
-                                                Number<A_K1>{}),
-                                     make_tuple(Number<A_K1 * A_KRow>{},
-                                                Number<KPack>{},
-                                                Number<A_K1 * A_KRow>{},
-                                                Number<A_K1>{},
-                                                Number<A_K1>{},
-                                                Number<1>{}));
+    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
+        make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
+        make_tuple(Number<A_K1>{},
+                   Number<KPack / A_KRow>{},
+                   Number<A_K1>{},
+                   Number<A_K1>{},
+                   Number<A_K1>{},
+                   Number<1>{}));
 
-    static constexpr auto b_thread_desc_ =
-        make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
-                                                Number<NRepeat>{},
-                                                I1,
-                                                Number<B_KRow>{},
-                                                I1,
-                                                Number<B_K1>{}),
-                                     make_tuple(Number<B_K1 * B_KRow>{},
-                                                Number<KPack>{},
-                                                Number<B_K1 * B_KRow>{},
-                                                Number<B_K1>{},
-                                                Number<B_K1>{},
-                                                Number<1>{}));
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
+        make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
+        make_tuple(Number<B_K1>{},
+                   Number<KPack / B_KRow>{},
+                   Number<B_K1>{},
+                   Number<B_K1>{},
+                   Number<B_K1>{},
+                   Number<1>{}));
 
     // C[M, N, NumRegWMMA]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
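Note: with the rewritten a_thread_desc_, a coordinate (k0, m0, 0, 0, 0, k1) maps to offset k0 * A_K1 + m0 * (KPack / A_KRow) + k1, i.e. each m0 slice of KPack / A_KRow elements is contiguous. A small check of that arithmetic with assumed sizes:

    #include <cassert>

    int main()
    {
        constexpr int KPack = 16, A_KRow = 2, A_K1 = 4, MRepeat = 2; // illustrative
        // strides as listed above; dims 2..4 have unit length, so their strides never contribute
        constexpr int stride[6] = {A_K1, KPack / A_KRow, A_K1, A_K1, A_K1, 1};
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int i = 0; i < KPack / A_KRow; ++i)
            {
                int offset = (i / A_K1) * stride[0] + m0 * stride[1] + (i % A_K1) * stride[5];
                assert(offset == m0 * (KPack / A_KRow) + i); // contiguous per m0 slice
            }
        return 0;
    }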
@@ -465,7 +446,7 @@ struct BlockwiseGemmWMMA
             FloatA,
             decltype(a_block_desc_k0_m0_m1_m2_k1),
             decltype(a_thread_desc_),
-            Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
+            Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
             Sequence<0, 1, 2, 3, 4, 5>,
             5,
             A_K1,
@@ -475,7 +456,7 @@ struct BlockwiseGemmWMMA
     template <>
     struct AThreadCopySelector<false>
     {
-        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatA,
             FloatA,
             decltype(a_block_desc_k0_m0_m1_m2_k1),
@@ -484,10 +465,7 @@ struct BlockwiseGemmWMMA
             Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
             Sequence<0, 1, 2, 3, 4, 5>,
             5,
-            A_K1,
-            0x76543210,
-            0xfedcba98,
-            TransposeC ? false : true>;
+            A_K1>;
     };
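Note: the dropped _InterRow transfer carried two lane-selection masks. If each hex nibble is read as a source lane (lowest nibble = slot 0), 0x76543210 is the identity permutation of lanes 0-7 and 0xfedcba98 selects lanes 8-15, which would match the cross-row data exchange described in the comment at the top of the file; that reading of the constants is an inference, not documented in this commit. Decoding sketch:

    #include <cstdio>

    int main()
    {
        constexpr unsigned masks[] = {0x76543210u, 0xfedcba98u};
        for(unsigned m : masks)
        {
            std::printf("0x%08x ->", m);
            for(int slot = 0; slot < 8; ++slot)              // one nibble per slot
                std::printf(" %u", (m >> (4 * slot)) & 0xFu); // assumed: source lane id
            std::printf("\n");
        }
    }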
     template <bool EnableLds>
@@ -501,7 +479,7 @@ struct BlockwiseGemmWMMA
             FloatB,
             decltype(b_block_desc_k0_n0_n1_n2_k1),
             decltype(b_thread_desc_),
-            Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+            Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
             Sequence<0, 1, 2, 3, 4, 5>,
             5,
             B_K1,
@@ -511,7 +489,7 @@ struct BlockwiseGemmWMMA
     template <>
     struct BThreadCopySelector<false>
     {
-        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatB,
            FloatB,
            decltype(b_block_desc_k0_n0_n1_n2_k1),
@@ -520,10 +498,7 @@ struct BlockwiseGemmWMMA
             Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
             Sequence<0, 1, 2, 3, 4, 5>,
             5,
-            B_K1,
-            0x76543210,
-            0xfedcba98,
-            TransposeC ? true : false>;
+            B_K1>;
     };
 
     typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
...
@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
     // Wave mode dependent propety
     static constexpr index_t wave_size = Number<WaveSize>{};
     // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
-    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
-    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
+    static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
     // * num_acc_vgprs_per_wave alone M direction
     // * num_subgroups alone M direction
     static constexpr index_t num_acc_vgprs_per_wave =
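Note: the source-operand VGPR count now follows K instead of M/N. For the f16 16x16x16 instruction named in this hunk header (k_per_wmma = 16, 2-byte elements), the new formula gives 16 / 2 * 2 / 4 = 4 VGPRs per wave versus 8 before. Arithmetic check:

    #include <cstdio>

    int main()
    {
        constexpr int k_per_wmma = 16, m_per_wmma = 16, src_a_data_size = 2; // f16 input
        constexpr int old_vgprs  = m_per_wmma * src_a_data_size / 4;         // 8
        constexpr int new_vgprs  = k_per_wmma / 2 * src_a_data_size / 4;     // 4
        std::printf("src A VGPRs per wave: old %d, new %d\n", old_vgprs, new_vgprs);
    }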
@@ -390,7 +390,7 @@ struct WmmaSelector
         static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
 
         static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
-                              selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
+                              selected_wmma.acc_data_size ==
                           selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                       "WRONG! Invalid Number of Accumulator Register");
     }
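Note: with acc_pack_number removed, the identity reduces to wave_size * num_acc_vgprs_per_wave * acc_data_size == m_per_wmma * n_per_wmma * 4. For an assumed wave32 with an f32 16x16 accumulator tile that is 32 * 8 * 4 == 16 * 16 * 4; a standalone check under those assumed values:

    int main()
    {
        // assumed gfx12 wave32 numbers, f32 accumulator
        constexpr int wave_size = 32, m_per_wmma = 16, n_per_wmma = 16, acc_data_size = 4;
        constexpr int num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; // 8
        static_assert(wave_size * num_acc_vgprs_per_wave * acc_data_size ==
                          m_per_wmma * n_per_wmma * 4,
                      "accumulator register count mismatch");
        return 0;
    }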
@@ -510,7 +510,7 @@ struct WmmaGemm
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
-        return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
+        return wmma_instr.num_acc_vgprs_per_wave;
     }
 
     __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
@@ -566,12 +566,14 @@ struct WmmaGemm
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
+        // return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
+        return GetLaneIdUnderSubGroup();
     }
 
     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
+        // return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
+        return GetLaneIdUnderSubGroup();
     }
 
     __device__ static CIndex GetBeginOfThreadBlk()
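Note: both operand origins now ignore TransposeC and use GetLaneIdUnderSubGroup(). Assuming a 16-lane subgroup (one WMMA row), that would be the lane id modulo 16, so lanes n and n + 16 of a wave32 start from the same source element; the modulo form below is a guess at the helper's behaviour, not its actual implementation:

    #include <cstdio>

    int main()
    {
        constexpr int subgroup_size = 16; // assumed: one row of a wave32
        for(int lane = 0; lane < 32; lane += 8)
            std::printf("lane %2d -> origin %2d\n", lane, lane % subgroup_size);
    }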
@@ -597,10 +599,7 @@ struct WmmaGemm
     __host__ __device__ static constexpr auto
     GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
     {
-        return make_tuple(I1,
-                          I1,
-                          Number<wmma_instr.num_acc_vgprs_per_wave>{},
-                          Number<wmma_instr.acc_pack_number>{});
+        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
     }
 };
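Note: the thread-block lengths tuple shrinks from four entries to three (MSubGroup, NThreadPerSubGroup, MAccVgprs), which is exactly the shape the descriptor hunk near the top of this commit unpacks via I0/I1/I2 before calling make_naive_tensor_descriptor_packed. A trivial shape sketch (the VGPR count is illustrative):

    #include <cstdio>
    #include <tuple>

    int main()
    {
        constexpr int num_acc_vgprs_per_wave = 8; // illustrative
        auto lens = std::make_tuple(1, 1, num_acc_vgprs_per_wave); // was 4 entries with AccStride
        std::printf("MSubGroup %d, NThreadPerSubGroup %d, MAccVgprs %d\n",
                    std::get<0>(lens), std::get<1>(lens), std::get<2>(lens));
    }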
...