Commit 4c102fcc authored by aska-0096's avatar aska-0096
Browse files

Solve a bug when K1=16

parent 18d5297b
...@@ -302,13 +302,13 @@ struct BlockwiseGemmWMMA ...@@ -302,13 +302,13 @@ struct BlockwiseGemmWMMA
// basic intrinsic to determine loopover direction // basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat) if constexpr(MRepeat < NRepeat)
{ {
static_for<0, KPerBlock / WmmaK, 1>{}( static_for<0, KPerBlock / KPack, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A // read A
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1, a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0), make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0), make_tuple(I0, m0, I0, I0, I0, I0),
...@@ -318,16 +318,16 @@ struct BlockwiseGemmWMMA ...@@ -318,16 +318,16 @@ struct BlockwiseGemmWMMA
// read B // read B
b_thread_copy_.Run( b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1, b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0), make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0), make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, KPack> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) = a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow, make_tuple(i / A_K1 / A_KRow,
...@@ -353,8 +353,8 @@ struct BlockwiseGemmWMMA ...@@ -353,8 +353,8 @@ struct BlockwiseGemmWMMA
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
...@@ -364,12 +364,12 @@ struct BlockwiseGemmWMMA ...@@ -364,12 +364,12 @@ struct BlockwiseGemmWMMA
{ {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, .. // k=0,kpack*1, ..
// read B // read B
b_thread_copy_.Run( b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1, b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0), make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0), make_tuple(I0, n0, I0, I0, I0, I0),
...@@ -377,16 +377,16 @@ struct BlockwiseGemmWMMA ...@@ -377,16 +377,16 @@ struct BlockwiseGemmWMMA
// read A // read A
a_thread_copy_.Run( a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1, a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0), make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0), make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, KPack> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) = b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow, make_tuple(i / B_K1 / B_KRow,
...@@ -412,8 +412,8 @@ struct BlockwiseGemmWMMA ...@@ -412,8 +412,8 @@ struct BlockwiseGemmWMMA
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
...@@ -423,28 +423,28 @@ struct BlockwiseGemmWMMA ...@@ -423,28 +423,28 @@ struct BlockwiseGemmWMMA
protected: protected:
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{}, make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{}, Number<MRepeat>{},
I1, I1,
Number<A_KRow>{}, Number<A_KRow>{},
I1, I1,
Number<A_K1>{}), Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{}, make_tuple(Number<A_K1 * A_KRow>{},
Number<WmmaK>{}, Number<KPack>{},
Number<A_K1 * A_KRow>{}, Number<A_K1 * A_KRow>{},
Number<A_K1>{}, Number<A_K1>{},
Number<A_K1>{}, Number<A_K1>{},
Number<1>{})); Number<1>{}));
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{}, make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{}, Number<NRepeat>{},
I1, I1,
Number<B_KRow>{}, Number<B_KRow>{},
I1, I1,
Number<B_K1>{}), Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{}, make_tuple(Number<B_K1 * B_KRow>{},
Number<WmmaK>{}, Number<KPack>{},
Number<B_K1 * B_KRow>{}, Number<B_K1 * B_KRow>{},
Number<B_K1>{}, Number<B_K1>{},
Number<B_K1>{}, Number<B_K1>{},
...@@ -465,7 +465,7 @@ struct BlockwiseGemmWMMA ...@@ -465,7 +465,7 @@ struct BlockwiseGemmWMMA
FloatA, FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>, Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
A_K1, A_K1,
...@@ -481,7 +481,7 @@ struct BlockwiseGemmWMMA ...@@ -481,7 +481,7 @@ struct BlockwiseGemmWMMA
decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_), decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>, Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
A_K1, A_K1,
...@@ -501,7 +501,7 @@ struct BlockwiseGemmWMMA ...@@ -501,7 +501,7 @@ struct BlockwiseGemmWMMA
FloatB, FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>, Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
B_K1, B_K1,
...@@ -517,7 +517,7 @@ struct BlockwiseGemmWMMA ...@@ -517,7 +517,7 @@ struct BlockwiseGemmWMMA
decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_), decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>, Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5>,
5, 5,
B_K1, B_K1,
......
...@@ -131,10 +131,12 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -131,10 +131,12 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto =
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
......
...@@ -89,10 +89,12 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout, ...@@ -89,10 +89,12 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto =
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
// LDS bypass feature not implemented for dequantization pipeline. // LDS bypass feature not implemented for dequantization pipeline.
......
...@@ -93,10 +93,12 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -93,10 +93,12 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto =
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
......
...@@ -86,10 +86,12 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -86,10 +86,12 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto =
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
......
...@@ -148,7 +148,7 @@ struct GridwiseFpAintBGemm_Wmma ...@@ -148,7 +148,7 @@ struct GridwiseFpAintBGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
...@@ -340,7 +340,7 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -340,7 +340,7 @@ struct GridwiseGemmMultipleD_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
...@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma ...@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
...@@ -373,7 +373,7 @@ struct WmmaGemm ...@@ -373,7 +373,7 @@ struct WmmaGemm
static_assert(NPerWmma == 16 && MPerWmma == 16, static_assert(NPerWmma == 16 && MPerWmma == 16,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
} }
// WMMA output supporting C = A * B // WMMA output supporting C = A * B
...@@ -486,14 +486,16 @@ struct WmmaGemm ...@@ -486,14 +486,16 @@ struct WmmaGemm
, ,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!"); "(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC) static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
{ if constexpr(!TransposeC)
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread); {
} wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
else }
{ else
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread); {
} wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
}
});
} }
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment