"include/vscode:/vscode.git/clone" did not exist on "ecdfe960921032c1aae6dc2c4a3e0ad1b8bba559"
Commit 4c102fcc authored by aska-0096's avatar aska-0096
Browse files

Solve a bug when K1=16

parent 18d5297b
......@@ -302,13 +302,13 @@ struct BlockwiseGemmWMMA
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / WmmaK, 1>{}(
static_for<0, KPerBlock / KPack, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
......@@ -318,16 +318,16 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
......@@ -353,8 +353,8 @@ struct BlockwiseGemmWMMA
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
......@@ -364,12 +364,12 @@ struct BlockwiseGemmWMMA
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
......@@ -377,16 +377,16 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
static_for<0, KPack, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
......@@ -412,8 +412,8 @@ struct BlockwiseGemmWMMA
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
......@@ -423,28 +423,28 @@ struct BlockwiseGemmWMMA
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<WmmaK>{},
Number<KPack>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<WmmaK>{},
Number<KPack>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
......@@ -465,7 +465,7 @@ struct BlockwiseGemmWMMA
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
......@@ -481,7 +481,7 @@ struct BlockwiseGemmWMMA
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
......@@ -501,7 +501,7 @@ struct BlockwiseGemmWMMA
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
......@@ -517,7 +517,7 @@ struct BlockwiseGemmWMMA
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
......
......@@ -131,10 +131,12 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
......
......@@ -89,10 +89,12 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally
// LDS bypass feature not implemented for dequantization pipeline.
......
......@@ -93,10 +93,12 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
......
......@@ -86,10 +86,12 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
......
......@@ -148,7 +148,7 @@ struct GridwiseFpAintBGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
......@@ -340,7 +340,7 @@ struct GridwiseGemmMultipleD_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
......@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
......@@ -373,7 +373,7 @@ struct WmmaGemm
static_assert(NPerWmma == 16 && MPerWmma == 16,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
}
// WMMA output supporting C = A * B
......@@ -486,14 +486,16 @@ struct WmmaGemm
,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
}
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
}
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment