Commit fc148cef authored by Chao Liu
Browse files

added back pipelined 2x2 to blockwise gemm

parent 0374f8de
......@@ -721,7 +721,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1<BlockSize,
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
......
......@@ -151,11 +151,27 @@ template <typename FloatA,
typename ADesc,
typename BDesc,
typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
// Constructor: performs compile-time validation of the template arguments only;
// the body contains nothing but static_asserts.
//   - ADesc, BDesc and CDesc must each report IsKnownAtCompileTime().
//   - KLengths::Size() must be 1; MLengths::Size() and NLengths::Size() must be 2.
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{
    static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                      CDesc::IsKnownAtCompileTime(),
                  "wrong! Desc should be known at compile-time");

    // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLengths, MLengths and NLengths

    // TODO: remove this restriction
    static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
                  "wrong! KLengths should have Size() == 1, MLengths and NLengths should have "
                  "Size() == 2");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
......@@ -169,10 +185,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
......@@ -192,11 +204,11 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto K = ADesc{}.GetLength(I0);
constexpr auto M0 = CDesc{}.GetLength(I0);
constexpr auto M1 = CDesc{}.GetLength(I1);
constexpr auto N0 = CDesc{}.GetLength(I2);
constexpr auto N1 = CDesc{}.GetLength(I3);
constexpr auto K = KLengths{}[I0];
constexpr auto M0 = MLengths{}[I0];
constexpr auto M1 = MLengths{}[I1];
constexpr auto N0 = NLengths{}[I0];
constexpr auto N1 = NLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment