gemm.cpp 2 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <migraph/cpu/gemm.hpp>
#include <migraph/requires.hpp>
#include <blaze/math/CustomMatrix.h>

namespace migraph {
namespace cpu {

// Blaze CustomMatrix adaptor over caller-provided element storage; the
// unaligned/unpadded flags place no alignment or padding requirement on it.
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT

// Builds a matrix<T> over the data pointer of a tensor view.
//
// The view's shape must have exactly two dimensions and be packed (both are
// checked with assert). When the shape reports itself as transposed, the two
// extents are passed in swapped order; visit_mat() then applies blaze::trans
// on top of the result.
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& shape = x.get_shape();
    assert(shape.lens().size() == 2);
    assert(shape.packed());

    // Pick the row/column extents once, then construct a single matrix.
    const bool swapped = shape.transposed();
    const auto rows    = swapped ? shape.lens()[1] : shape.lens()[0];
    const auto cols    = swapped ? shape.lens()[0] : shape.lens()[1];
    return matrix<T>{x.data(), rows, cols};
}

// Invokes f with a matrix for the tensor view x.
//
// f receives the matrix from make_mat() directly when the shape is not
// transposed, and blaze::trans() of that matrix otherwise. The two cases pass
// different argument types, so f must be callable with either.
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto m = make_mat(x);
    if(not x.get_shape().transposed())
    {
        f(m);
        return;
    }
    f(blaze::trans(m));
}

// Trait used to tag-dispatch migemm_impl: derives from std::true_type when T
// is exactly float and from std::false_type for every other type.
template<class T>
struct is_fast_gemm_type
: std::is_same<T, float>
{};

// GEMM overload selected by the std::true_type tag.
//
// Writes into cmat either a * b (when alpha == 1 and beta == 0) or
// (a * b) * alpha + beta * c, where a and b are the matrices produced by
// visit_mat() for amat and bmat, and c is make_mat(cmat).
template<class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta, std::true_type)
{
    // With these coefficients the previous contents of c are not read.
    const bool plain_product = (alpha == 1.0 and beta == 0.0);

    visit_mat(amat, [&](const auto& lhs) {
        visit_mat(bmat, [&](const auto& rhs) {
            auto out = make_mat(cmat);
            if(plain_product)
            {
                out = lhs * rhs;
                return;
            }
            out = (lhs * rhs) * alpha + beta * out;
        });
    });
}

// GEMM overload selected by the std::false_type tag, i.e. for element types
// without a specialised implementation.
//
// Not implemented yet: the arguments are ignored and cmat is left unmodified.
// Reaching this overload trips an assertion in builds where assert is active;
// with NDEBUG defined it remains a no-op.
template<class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta, std::false_type)
{
    (void)cmat;
    (void)amat;
    (void)bmat;
    (void)alpha;
    (void)beta;
    // The condition must be false: `true && "TODO"` can never fire, which hid
    // the missing implementation instead of reporting it.
    assert(false && "TODO");
}

// Entry point for a single element type: forwards to the tagged overload
// chosen by is_fast_gemm_type<T>.
template<class T>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
    using dispatch_tag = is_fast_gemm_type<T>;
    migemm_impl(cmat, amat, bmat, alpha, beta, dispatch_tag{});
}

// Runs GEMM on three arguments: visit_all() supplies the typed views for
// c_arg, a_arg and b_arg, which are forwarded to migemm_impl together with
// the alpha and beta coefficients.
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
    auto run_typed = [&](auto cmat, auto amat, auto bmat) {
        migemm_impl(cmat, amat, bmat, alpha, beta);
    };
    visit_all(c_arg, a_arg, b_arg)(run_typed);
}

} // namespace cpu

} // namespace migraph