Commit f221c68e authored by Jing Zhang's avatar Jing Zhang
Browse files

merge navi3_ref

parent 37560a6d
......@@ -66,8 +66,8 @@ struct BlockwiseGemmWMMA
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
......@@ -213,19 +213,20 @@ struct BlockwiseGemmWMMA
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
Number<NRepeat>{},
I1,
NThreadPerSubGroup,
MAccVgprs));
}
template <typename CGridDesc_M_N>
......@@ -324,30 +325,25 @@ struct BlockwiseGemmWMMA
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -383,30 +379,25 @@ struct BlockwiseGemmWMMA
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -422,33 +413,23 @@ struct BlockwiseGemmWMMA
}
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<KPack>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<KPack>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......@@ -465,7 +446,7 @@ struct BlockwiseGemmWMMA
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
......@@ -475,7 +456,7 @@ struct BlockwiseGemmWMMA
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
......@@ -484,10 +465,7 @@ struct BlockwiseGemmWMMA
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
0x76543210,
0xfedcba98,
TransposeC ? false : true>;
A_K1>;
};
template <bool EnableLds>
......@@ -501,7 +479,7 @@ struct BlockwiseGemmWMMA
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
......@@ -511,7 +489,7 @@ struct BlockwiseGemmWMMA
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
......@@ -520,10 +498,7 @@ struct BlockwiseGemmWMMA
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
0x76543210,
0xfedcba98,
TransposeC ? true : false>;
B_K1>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
......
......@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
......@@ -390,7 +390,7 @@ struct WmmaSelector
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
}
......@@ -510,7 +510,7 @@ struct WmmaGemm
__device__ static constexpr index_t GetRegSizePerWmma()
{
return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
return wmma_instr.num_acc_vgprs_per_wave;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
......@@ -566,12 +566,14 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
// return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
return GetLaneIdUnderSubGroup();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
// return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
return GetLaneIdUnderSubGroup();
}
__device__ static CIndex GetBeginOfThreadBlk()
......@@ -597,10 +599,7 @@ struct WmmaGemm
__host__ __device__ static constexpr auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(I1,
I1,
Number<wmma_instr.num_acc_vgprs_per_wave>{},
Number<wmma_instr.acc_pack_number>{});
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment