gemm.cpp 4.14 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/cpu/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
4
#include <migraphx/shape_for_each.hpp>
Paul's avatar
Paul committed
5
6
#include <blaze/math/CustomMatrix.h>

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
9
10
11
12
13
14
15
16
17
namespace cpu {

// Alias for blaze::CustomMatrix over element type T, instantiated with the
// blaze::unaligned and blaze::unpadded tags. make_mat() constructs it from a
// tensor's data pointer, two extents and a stride.
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT

// Builds a blaze matrix view over the two trailing dimensions of `x`.
//
// The view is constructed from x.data(), two extents and one stride. When the
// shape reports transposed(), the extents are passed in swapped order together
// with the stride of the last dimension; otherwise they are passed in order
// together with the stride of the second-to-last dimension. visit_mat()
// applies blaze::trans() to the view in the transposed case.
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& s            = x.get_shape();
    const std::size_t n_dims = s.lens().size();
    // dim_0 / dim_1 are n_dims - 2 / n_dims - 1; both are unsigned and would
    // wrap around for a tensor with fewer than two dimensions.
    assert(n_dims >= 2);
    const std::size_t dim_0 = n_dims - 2;
    const std::size_t dim_1 = n_dims - 1;
    if(s.transposed())
        return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
    return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}

// Invokes `f` with a blaze view of `x`: the view from make_mat() as-is for a
// non-transposed shape, or wrapped in blaze::trans() for a transposed one.
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto view = make_mat(x);
    if(!x.get_shape().transposed())
    {
        f(view);
        return;
    }
    f(blaze::trans(view));
}

Paul's avatar
Paul committed
37
38
39
40
// Compile-time tag used by migemm_impl to pick between its std::true_type and
// std::false_type overloads. It is false for every type by default.
template <class T>
struct is_fast_gemm_type : std::false_type
{
};

// float is the only element type for which the tag is true.
template <>
struct is_fast_gemm_type<float> : std::true_type
{
};
Paul's avatar
Paul committed
46

Paul's avatar
Paul committed
47
48
49
50
51
52
53
// GEMM overload selected by a std::true_type tag.
//
// Views amat and bmat through visit_mat() and cmat through make_mat(), then
// evaluates c = beta * c followed, when alpha is nonzero, by
// c = c + alpha * a * b as blaze expressions. The result is written into the
// storage viewed by cmat.
template <class T>
void migemm_impl(tensor_view<T> cmat,
                 tensor_view<T> amat,
                 tensor_view<T> bmat,
                 float alpha,
                 float beta,
                 std::true_type)
{
    visit_mat(amat, [&](const auto& a) {
        visit_mat(bmat, [&](const auto& b) {
            auto c = make_mat(cmat);
            // NOTE(review): c is scaled rather than cleared, so with beta == 0
            // a NaN/Inf already stored in c would still propagate -- confirm
            // callers initialise the output buffer.
            c = beta * c;

            // Skip the product entirely when it cannot contribute.
            if(alpha != 0.0f)
            {
                c = c + alpha * a * b;
            }
        });
    });
}

Paul's avatar
Paul committed
68
69
70
71
72
73
74
// GEMM overload selected by a std::false_type tag: an element-by-element
// implementation that also handles leading (batch) dimensions.
//
// For every index c_idx of cmat it computes
//     cmat[c_idx] = alpha * sum_kk(amat[a_idx] * bmat[b_idx]) + beta * cmat[c_idx]
// where a_idx / b_idx are derived from the trailing entries of c_idx. Any
// dimension of amat / bmat whose extent is 1 is indexed with 0, so such
// dimensions are broadcast against cmat. The sum is accumulated in double.
template <class T>
void migemm_impl(tensor_view<T> cmat,
                 tensor_view<T> amat,
                 tensor_view<T> bmat,
                 float alpha,
                 float beta,
                 std::false_type)
{
    auto a_lens = amat.get_shape().lens();
    auto b_lens = bmat.get_shape().lens();
    auto c_lens = cmat.get_shape().lens();

    std::size_t nc_dims = c_lens.size();
    std::size_t na_dims = a_lens.size();
    std::size_t nb_dims = b_lens.size();

    // All index arithmetic below is unsigned; these guard against wrap-around.
    assert(na_dims >= 2 && nb_dims >= 2);
    assert(nc_dims >= na_dims && nc_dims >= nb_dims);

    // Length of the contracted dimension: columns of A, rows of B.
    auto k = a_lens[na_dims - 1];

    assert(a_lens[na_dims - 1] == b_lens[nb_dims - 2]);
    assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]);
    assert(c_lens[nc_dims - 1] == b_lens[nb_dims - 1]);

    // Number of leading dimensions cmat has in excess of amat / bmat.
    std::size_t a_len_diff = nc_dims - na_dims;
    std::size_t b_len_diff = nc_dims - nb_dims;
    std::vector<std::size_t> a_idx(na_dims);
    std::vector<std::size_t> b_idx(nb_dims);

    shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
        // Map the current output index (not the output extents) onto the
        // operands, using index 0 along any operand dimension of extent 1.
        std::transform(c_idx.begin() + a_len_diff,
                       c_idx.end(),
                       a_lens.begin(),
                       a_idx.begin(),
                       [](auto i, auto j) { return (j == 1) ? 0 : i; });
        std::transform(c_idx.begin() + b_len_diff,
                       c_idx.end(),
                       b_lens.begin(),
                       b_idx.begin(),
                       [](auto i, auto j) { return (j == 1) ? 0 : i; });

        double s = 0.0;
        dfor(k)([&](auto kk) {
            // The contracted position overrides whatever the mapping above
            // stored in A's last and B's second-to-last index.
            a_idx[na_dims - 1] = b_idx[nb_dims - 2] = kk;
            s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
        });
        cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
    });
}

Paul's avatar
Paul committed
115
116
117
// Chooses the GEMM overload for the given tensors.
//
// The number of matrices is the product of every dimension of cmat except the
// last two. A single matrix is dispatched on is_fast_gemm_type<T>; anything
// else always goes to the std::false_type overload.
template <class T>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
    auto lens = cmat.get_shape().lens();
    // lens.rbegin() + 2 is only a valid iterator when there are at least two
    // dimensions.
    assert(lens.size() >= 2);
    std::size_t num_matrices = std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
    if(num_matrices == 1)
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
    }
    else
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
    }
}

Paul's avatar
Paul committed
132
133
void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
Paul's avatar
Paul committed
134
{
Paul's avatar
Paul committed
135
136
    visit_all(c_arg, a_arg, b_arg)(
        [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
Paul's avatar
Paul committed
137
138
139
}

} // namespace cpu
Paul's avatar
Paul committed
140
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
141
} // namespace migraphx