Commit 5b1e442e authored by Paul
Browse files

Try to speed up compilation

parent e1e37208
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
add_library(migraph_cpu add_library(migraph_cpu
cpu_target.cpp cpu_target.cpp
cpu_lowering.cpp cpu_lowering.cpp
gemm.cpp
) )
find_path(BLAZE_INCLUDE blaze/Blaze.h) find_path(BLAZE_INCLUDE blaze/Blaze.h)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/shape_for_each.hpp> #include <migraph/shape_for_each.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <blaze/Blaze.h> #include <migraph/cpu/gemm.hpp>
#include <unordered_map> #include <unordered_map>
namespace migraph { namespace migraph {
...@@ -227,44 +227,10 @@ struct cpu_gemm ...@@ -227,44 +227,10 @@ struct cpu_gemm
std::string name() const { return "cpu::gemm"; } std::string name() const { return "cpu::gemm"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT
template <class T>
static auto make_mat(tensor_view<T> x)
{
const auto& s = x.get_shape();
assert(s.lens().size() == 2);
assert(s.packed());
if(s.transposed())
return matrix<T>{x.data(), s.lens()[1], s.lens()[0]};
return matrix<T>{x.data(), s.lens()[0], s.lens()[1]};
}
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
auto mat = make_mat(x);
if(x.get_shape().transposed())
f(blaze::trans(mat));
else
f(mat);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) { migemm(result, args[0], args[1], op.alpha, op.beta);
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
if(op.alpha == 1.0 and op.beta == 0.0)
c = a * b;
else
c = (a * b) * op.alpha + op.beta * c;
});
});
});
return result; return result;
} }
}; };
......
#include <migraph/cpu/gemm.hpp>
#include <migraph/requires.hpp>
#include <blaze/math/CustomMatrix.h>
#include <stdexcept>
namespace migraph {
namespace cpu {
// Blaze matrix adaptor used to run GEMM directly on a tensor's own storage
// (make_mat constructs it from the view's data() pointer). The
// unaligned/unpadded flags mean no alignment or row padding is demanded of
// that buffer.
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT
// Wraps a packed, two-dimensional tensor view in a Blaze matrix without
// copying its elements.
//
// When the shape reports itself as transposed the two extents are swapped, so
// the returned matrix has lens()[1] rows and lens()[0] columns; visit_mat
// applies blaze::trans on top of it in that case.
//
// Preconditions (checked with assert): exactly two dimensions, packed layout.
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& s = x.get_shape();
    assert(s.lens().size() == 2);
    assert(s.packed());
    // Pick which extent plays the role of rows/columns up front so a single
    // construction site serves both layouts.
    const bool swap_extents = s.transposed();
    const auto rows         = s.lens()[swap_extents ? 1 : 0];
    const auto cols         = s.lens()[swap_extents ? 0 : 1];
    return matrix<T>{x.data(), rows, cols};
}
// Invokes `f` with a Blaze expression for `x`.
//
// A non-transposed view is passed as the matrix built by make_mat; a
// transposed view is passed as blaze::trans of that matrix. The two cases have
// different expression types, which is why the result is delivered through a
// callback instead of a return value.
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto mat = make_mat(x);
    if(not x.get_shape().transposed())
    {
        f(mat);
        return;
    }
    f(blaze::trans(mat));
}
// Compile-time predicate selecting the Blaze-backed GEMM overload.
//
// Derives from std::true_type for exactly `float` and from std::false_type for
// every other type, so an instance can be passed as a dispatch tag to the
// migemm_impl overloads.
template<class T>
struct is_fast_gemm_type
    : std::is_same<T, float>
{};
// Blaze-backed GEMM, chosen when is_fast_gemm_type<T> is std::true_type.
//
// Writes into `cmat`:
//   - a * b                          when alpha == 1 and beta == 0
//   - (a * b) * alpha + beta * c     otherwise (reads the current contents of c)
// `amat` and `bmat` are transposed on the fly by visit_mat when their shapes
// say so; `cmat` is wrapped by make_mat only.
template<class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta, std::true_type)
{
    // The scalars do not depend on the operands, so decide once which form of
    // the update is needed.
    const bool plain_product = (alpha == 1.0 and beta == 0.0);
    visit_mat(amat, [&](const auto& lhs) {
        visit_mat(bmat, [&](const auto& rhs) {
            auto out = make_mat(cmat);
            if(plain_product)
            {
                out = lhs * rhs;
            }
            else
            {
                out = (lhs * rhs) * alpha + beta * out;
            }
        });
    });
}
template<class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta, std::false_type)
{
(void)cmat;
(void)amat;
(void)bmat;
(void)alpha;
(void)beta;
assert(true && "TODO");
}
// Entry point for a single element type: forwards to the overload selected by
// is_fast_gemm_type<T> (std::true_type or std::false_type tag).
template<class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
    using kernel_tag = is_fast_gemm_type<T>;
    migemm_impl(cmat, amat, bmat, alpha, beta, kernel_tag{});
}
// Type-erased GEMM entry point declared in migraph/cpu/gemm.hpp.
//
// visit_all supplies one view per argument to the callback, which hands them
// to the typed migemm_impl together with the alpha/beta scalars. The result is
// written through the view of `c_arg`.
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
    auto run_typed = [&](auto cmat, auto amat, auto bmat) {
        migemm_impl(cmat, amat, bmat, alpha, beta);
    };
    visit_all(c_arg, a_arg, b_arg)(run_typed);
}
} // namespace cpu
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraph/argument.hpp>
namespace migraph {
namespace cpu {
// General matrix multiply on CPU arguments.
//
// For float data the result written into `c_arg` is `a * b` when alpha == 1
// and beta == 0, and `(a * b) * alpha + beta * c` otherwise (so the existing
// contents of `c_arg` are read in that case). Only float element types are
// currently implemented.
//
// The implementation asserts that each argument is a packed, two-dimensional
// tensor; `a_arg` and `b_arg` may have transposed shapes.
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
} // namespace cpu
} // namespace migraph
#endif
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment