Commit 0b5fa390 authored by Paul's avatar Paul
Browse files

Formatting

parent 5b1e442e
...@@ -29,18 +29,23 @@ static void visit_mat(tensor_view<T> x, F f) ...@@ -29,18 +29,23 @@ static void visit_mat(tensor_view<T> x, F f)
f(mat); f(mat);
} }
template<class T> template <class T>
struct is_fast_gemm_type struct is_fast_gemm_type : std::false_type
: std::false_type {
{}; };
template<> template <>
struct is_fast_gemm_type<float> struct is_fast_gemm_type<float> : std::true_type
: std::true_type {
{}; };
template<class T> template <class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta, std::true_type) void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
float alpha,
float beta,
std::true_type)
{ {
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
...@@ -53,8 +58,13 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, ...@@ -53,8 +58,13 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat,
}); });
} }
template<class T> template <class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta, std::false_type) void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
float alpha,
float beta,
std::false_type)
{ {
(void)cmat; (void)cmat;
(void)amat; (void)amat;
...@@ -64,17 +74,18 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, ...@@ -64,17 +74,18 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat,
assert(true && "TODO"); assert(true && "TODO");
} }
template<class T> template <class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta) void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{ {
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{}); migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
} }
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta) void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{ {
visit_all(c_arg, a_arg, b_arg)([&](auto cmat, auto amat, auto bmat) { visit_all(c_arg, a_arg, b_arg)(
migemm_impl(cmat, amat, bmat, alpha, beta); [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
});
} }
} // namespace cpu } // namespace cpu
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
namespace migraph { namespace migraph {
namespace cpu { namespace cpu {
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta); void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
} // namespace cpu } // namespace cpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment