Commit fc148cef authored by Chao Liu
Browse files

added back pipelined 2x2 to blockwise gemm

parent 0374f8de
...@@ -721,22 +721,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -721,22 +721,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{})); Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1<BlockSize, BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_k_m0_m1_block_desc), decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc), decltype(b_k_n0_n1_block_desc),
decltype(c_m0_m1_n0_n1_thread_desc), decltype(c_m0_m1_n0_n1_thread_desc),
MPerThread, MPerThread,
NPerThread, NPerThread,
KPerThread, KPerThread,
MLevel0Cluster, MLevel0Cluster,
NLevel0Cluster, NLevel0Cluster,
MLevel1Cluster, MLevel1Cluster,
NLevel1Cluster, NLevel1Cluster,
MPerThread, MPerThread,
NPerThread>{}; NPerThread>{};
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
......
...@@ -151,11 +151,27 @@ template <typename FloatA, ...@@ -151,11 +151,27 @@ template <typename FloatA,
typename ADesc, typename ADesc,
typename BDesc, typename BDesc,
typename CDesc, typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{ {
// Constructor used purely for compile-time validation of the template
// arguments: its body consists solely of static_asserts.
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{
    // Each of the three descriptors must be known at compile-time.
    static_assert(ADesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
    static_assert(BDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
    static_assert(CDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");

    // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLengths, MLengths and NLengths

    // TODO: remove this restriction
    // For now KLengths must hold exactly one entry, and MLengths / NLengths
    // exactly two entries each.
    static_assert(KLengths::Size() == 1, "wrong!");
    static_assert(MLengths::Size() == 2, "wrong!");
    static_assert(NLengths::Size() == 2, "wrong!");
}
template <typename ABuffer, template <typename ABuffer,
typename AOriginIdx, typename AOriginIdx,
typename BBuffer, typename BBuffer,
...@@ -169,10 +185,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -169,10 +185,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
CBuffer& c_buf, CBuffer& c_buf,
COriginIdx) COriginIdx)
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
...@@ -192,11 +204,11 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -192,11 +204,11 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto K = ADesc{}.GetLength(I0); constexpr auto K = KLengths{}[I0];
constexpr auto M0 = CDesc{}.GetLength(I0); constexpr auto M0 = MLengths{}[I0];
constexpr auto M1 = CDesc{}.GetLength(I1); constexpr auto M1 = MLengths{}[I1];
constexpr auto N0 = CDesc{}.GetLength(I2); constexpr auto N0 = NLengths{}[I0];
constexpr auto N1 = CDesc{}.GetLength(I3); constexpr auto N1 = NLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment