"vscode:/vscode.git/clone" did not exist on "1e2ef8faeb8fc7403e82cdc6212ee6bfe28ea8e4"
gemm.cpp 2.55 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/cpu/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
Paul's avatar
Paul committed
4
5
#include <blaze/math/CustomMatrix.h>

Paul's avatar
Paul committed
6
namespace migraphx {
7
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
8
9
10
11
12
13
14
15
16
17
18
namespace cpu {

// Alias for the Blaze matrix type used by the fast GEMM path: a
// blaze::CustomMatrix over element type T, declared unaligned and unpadded.
// NOTE(review): CustomMatrix is understood to wrap caller-provided storage
// rather than own it -- confirm against the Blaze documentation.
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT

// Wrap a 2-D tensor_view in a Blaze `matrix<T>` over the view's data pointer.
//
// For a shape that reports `transposed()`, the two extents are swapped and
// strides()[1] is passed as the fourth constructor argument; otherwise the
// extents are used in order together with strides()[0].
// The shape must have exactly two dimensions (checked with assert only).
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& shape = x.get_shape();
    assert(shape.lens().size() == 2);

    // Pick which dimension supplies the rows, the columns and the spacing.
    const bool swapped       = shape.transposed();
    const std::size_t row_ix = swapped ? 1 : 0;
    const std::size_t col_ix = swapped ? 0 : 1;

    const auto rows    = shape.lens()[row_ix];
    const auto cols    = shape.lens()[col_ix];
    const auto spacing = shape.strides()[row_ix];

    return matrix<T>{x.data(), rows, cols, spacing};
}

// Build the Blaze view of `x` with make_mat and hand it to `f`.
//
// When the shape reports `transposed()`, `f` receives `blaze::trans(view)`
// instead of the view itself, so the callback may be invoked with two
// different argument types and must be generic.
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto view = make_mat(x);
    if(!x.get_shape().transposed())
    {
        f(view);
        return;
    }
    f(blaze::trans(view));
}

Paul's avatar
Paul committed
33
34
35
36
// Compile-time trait used to pick a migemm_impl overload by tag dispatch.
// The primary template is false for every element type.
template <class T>
struct is_fast_gemm_type : std::integral_constant<bool, false>
{
};

// float is the only element type marked as fast.
template <>
struct is_fast_gemm_type<float> : std::integral_constant<bool, true>
{
};
Paul's avatar
Paul committed
42

Paul's avatar
Paul committed
43
44
45
46
47
48
49
// Fast GEMM path, selected by the std::true_type tag.
//
// Wraps the three tensor views as Blaze matrices and evaluates
//     c = (a * b) * alpha + beta * c
// in place on the storage behind `cmat`. `amat` and `bmat` go through
// visit_mat, so a transposed input is seen as a Blaze transpose expression.
// NOTE(review): `cmat` is wrapped with make_mat only (no transpose view);
// presumably the output is never transposed -- confirm with callers.
template <class T>
void migemm_impl(tensor_view<T> cmat,
                 tensor_view<T> amat,
                 tensor_view<T> bmat,
                 float alpha,
                 float beta,
                 std::true_type)
{
    auto out = make_mat(cmat);
    visit_mat(amat, [&](const auto& lhs) {
        visit_mat(bmat, [&](const auto& rhs) { out = (lhs * rhs) * alpha + beta * out; });
    });
}

Paul's avatar
Paul committed
59
60
61
62
63
64
65
// Generic GEMM path, selected by the std::false_type tag.
//
// Computes, for every output element,
//     c(i, j) = alpha * sum_k a(i, k) * b(k, j) + beta * c(i, j)
// which is the same formula the fast path evaluates. The dot product is
// accumulated in double and converted back to T on the final store.
//
// cmat  - m x n output, also read as the beta-scaled input term
// amat  - m x k left operand
// bmat  - k x n right operand
// alpha - scale applied to the product a * b only
// beta  - scale applied to the previous contents of cmat only
//
// Dimension agreement is checked with assert only.
template <class T>
void migemm_impl(tensor_view<T> cmat,
                 tensor_view<T> amat,
                 tensor_view<T> bmat,
                 float alpha,
                 float beta,
                 std::false_type)
{
    auto m = cmat.get_shape().lens()[0];
    auto n = cmat.get_shape().lens()[1];
    auto k = amat.get_shape().lens()[1];

    assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
    assert(m == amat.get_shape().lens()[0]);
    assert(n == bmat.get_shape().lens()[1]);

    dfor(m, n)([&](auto ii, auto jj) {
        // Keep the beta term apart from the accumulator so that alpha scales
        // the product only; seeding the accumulator with beta * c and scaling
        // the total by alpha would yield alpha * beta * c instead.
        const double scaled_c = cmat(ii, jj) * beta;
        double acc            = 0.0;
        dfor(k)([&](auto kk) { acc += amat(ii, kk) * bmat(kk, jj); });
        cmat(ii, jj) = alpha * acc + scaled_c;
    });
}

Paul's avatar
Paul committed
82
83
84
// Entry point for a concrete element type T: forwards to the tagged
// migemm_impl overload chosen by is_fast_gemm_type<T>.
template <class T>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
    using path_tag = is_fast_gemm_type<T>;
    migemm_impl(cmat, amat, bmat, alpha, beta, path_tag{});
}

Paul's avatar
Paul committed
89
90
void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
Paul's avatar
Paul committed
91
{
Paul's avatar
Paul committed
92
93
    visit_all(c_arg, a_arg, b_arg)(
        [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
Paul's avatar
Paul committed
94
95
96
}

} // namespace cpu
97
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
98
} // namespace migraphx