gemm.cpp 3.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include <migraphx/ref/gemm.hpp>
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <type_traits>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp>
#include <blaze/math/CustomMatrix.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {

// Non-owning blaze matrix view over memory owned elsewhere (a tensor_view's buffer).
// No alignment or padding is assumed, so any element pointer and row spacing can be wrapped.
template <typename T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT

/// Wraps the trailing two dimensions of `x` as a blaze `matrix` view over x's data (no copy).
///
/// The view is built from the last two lens/strides only; leading (batch) dimensions are
/// ignored, so the result is only meaningful when their product is one.
/// For a transposed shape, rows and columns are swapped and the row spacing comes from the
/// last stride. The returned view then describes the underlying storage, and the logical
/// orientation must be recovered with `blaze::trans`.
///
/// @param x tensor of rank >= 2
/// @return a blaze CustomMatrix aliasing x.data()
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& s = x.get_shape();
    const std::size_t n_dims = s.lens().size();
    // Indexing n_dims - 2 below is out of bounds for rank < 2.
    assert(n_dims >= 2);
    const std::size_t dim_0 = n_dims - 2;
    const std::size_t dim_1 = n_dims - 1;

    if(s.transposed())
        return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
    return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}

/// Calls `f` with a blaze view of `x` in its logical orientation.
///
/// The storage view from make_mat is passed unchanged for non-transposed shapes.
/// For transposed shapes it is wrapped in `blaze::trans`, so `f` must accept both
/// expression types (e.g. a generic lambda).
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto storage_view = make_mat(x);
    if(not x.get_shape().transposed())
    {
        f(storage_view);
        return;
    }
    f(blaze::trans(storage_view));
}

// Compile-time flag choosing the blaze fast path: true only for exactly `float`.
// Every other element type (including cv-qualified float) uses the naive loop.
// An instance of this trait converts to std::true_type / std::false_type for tag dispatch.
template <typename T>
struct is_fast_gemm_type : std::is_same<T, float>
{
};

/// Fast (blaze) GEMM on a single 2-D problem: C = alpha * A * B + beta * C.
///
/// Selected by the dispatcher only when T has a fast implementation (is_fast_gemm_type)
/// and A has no batch dimension larger than one.
///
/// BLAS convention: when beta is zero, the previous contents of C are never read.
/// An uninitialized or NaN-filled output buffer therefore cannot leak into the result
/// (0 * NaN is NaN). When alpha is zero, the A * B product is skipped.
///
/// NOTE(review): C is wrapped with make_mat directly, without the transpose handling that
/// visit_mat applies to A and B. A transposed output shape is presumably not supported here.
template <class T, class F>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
{
    visit_mat(amat, [&](const auto& a) {
        visit_mat(bmat, [&](const auto& b) {
            auto c = make_mat(cmat);
            if(beta == F{0})
                c = T{0}; // overwrite, don't scale: the old C must not be read
            else
                c = beta * c;
            // Simple optimization: avoid computing A * B when it is scaled by zero.
            if(alpha != F{0})
                c += alpha * a * b;
        });
    });
}

/// Naive reference GEMM, used for element types without a fast path and for batched inputs.
///
/// For every index of C (batch indices included), this writes
///   C[..., i, j] = alpha * sum_k A[..., i, k] * B[..., k, j] + beta * C[..., i, j]
/// with the dot product accumulated in double. Batch indices of A and B are copied from C's
/// index. NOTE(review): matching batch dimensions (no broadcasting) are assumed; only the
/// matrix dimensions are asserted.
///
/// As in the fast path, with beta == 0 the old value of C is never read (BLAS convention).
template <class T, class F>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
    const std::size_t n_dims = cmat.get_shape().lens().size();
    assert(n_dims >= 2);
    const std::size_t dim_0 = n_dims - 2;
    const std::size_t dim_1 = n_dims - 1;
    // Length of the reduction (inner) dimension.
    auto k = amat.get_shape().lens()[dim_1];

    assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);

    shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
        auto a_idx = c_idx;
        auto b_idx = c_idx;
        double s   = 0.0;
        dfor(k)([&](auto kk) {
            // A walks along its last axis, B along its second-to-last axis.
            a_idx[dim_1] = b_idx[dim_0] = kk;
            s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
        });
        if(beta == F{0})
            cmat(c_idx.begin(), c_idx.end()) = alpha * s;
        else
            cmat(c_idx.begin(), c_idx.end()) =
                alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
    });
}

/// Chooses the GEMM implementation for typed tensor views.
///
/// The product of A's leading (batch) dimensions, i.e. all but the last two lens, is
/// computed first. If it is one, the problem is a single matrix multiply, and the fast
/// blaze path is used whenever T supports it (is_fast_gemm_type). Otherwise the naive
/// loop, which handles batch indices, is used.
template <class T, class F>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
    auto lens = amat.get_shape().lens();
    // rbegin() + 2 below is undefined behavior for rank < 2.
    assert(lens.size() >= 2);
    // Accumulate in size_t: lens are size_t, and an int product could overflow.
    const std::size_t batch_count = std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
    const bool single_matrix = (batch_count == 1);

    if(single_matrix)
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
    }
    else
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
    }
}

/// Shared body of the public migemm overloads.
///
/// visit_all resolves the three untyped arguments into typed tensor views, and the
/// lambda forwards them with the scalar coefficients to the type-dispatching migemm_impl.
template <class F>
void migemm_tpl(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
    auto gemm_on_views = [&](auto cmat, auto amat, auto bmat) {
        migemm_impl(cmat, amat, bmat, alpha, beta);
    };
    visit_all(c_arg, a_arg, b_arg)(gemm_on_views);
}

void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
    migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}

void migemm(const argument& c_arg,
            const argument& a_arg,
            const argument& b_arg,
            int32_t alpha,
            int32_t beta)
{
    migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}

} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx