gemm.cpp 5.4 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
#include <cstring>
#include <functional>
#include <numeric>
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// Single-precision overload: hands every argument unchanged to
// rocblas_sgemm_strided_batched.
// NOTE(review): the rocBLAS status returned by the call is discarded.
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// Double-precision overload: hands every argument unchanged to
// rocblas_dgemm_strided_batched.
// NOTE(review): the rocBLAS status returned by the call is discarded.
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// Half-precision overload: hands every argument unchanged to
// rocblas_hgemm_strided_batched.
// NOTE(review): the rocBLAS status returned by the call is discarded.
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Catch-all overload for element types that have no rocBLAS strided-batched
// GEMM routine in this file. Overload resolution prefers the more specialized
// float/double/half overloads, so this is only reached for other types.
// Throws: always.
template <typename T, typename... Args>
void generic_rocblas_batched_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
32
template <class... Ts>
Paul's avatar
Paul committed
33
34
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
35
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
36
37
}

Paul's avatar
Paul committed
38
template <class... Ts>
Paul's avatar
Paul committed
39
40
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
41
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
42
43
}

Paul's avatar
Paul committed
44
template <class... Ts>
Paul's avatar
Paul committed
45
46
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
47
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
48
49
}

Paul's avatar
Paul committed
50
template <class T, class... Ts>
Paul's avatar
Paul committed
51
52
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
53
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
54
55
}

Paul's avatar
Paul committed
56
// Type trait mapping a MIGraphX element type to the element type passed to
// rocBLAS. By default the mapping is the identity; specializations override
// it for types that rocBLAS represents differently.
template <typename T>
struct compute_rocblas_type
{
    using type = T; // identity mapping
};

Paul's avatar
Paul committed
62
template <class T>
Paul's avatar
Paul committed
63
64
65
66
67
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
68
template <>
Paul's avatar
Paul committed
69
70
71
72
73
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
74
template <class T>
Paul's avatar
Paul committed
75
76
77
78
79
80
81
82
using rb_type = typename compute_rocblas_type<T>::type;

// Reinterpret a scalar as its rocBLAS-facing type (rb_type<T>) and return it
// by value. The bits are reinterpreted, not numerically converted.
template <typename T>
rb_type<T> to_rocblas_type(T x)
{
    const rb_type<T>& reinterpreted = reinterpret_cast<const rb_type<T>&>(x);
    return reinterpreted;
}

Paul's avatar
Paul committed
83
template <class T>
Paul's avatar
Paul committed
84
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
85
{
Paul's avatar
Paul committed
86
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
87
88
}

Paul's avatar
Paul committed
89
// Convert a MIGraphX half scalar to rocBLAS's half representation by copying
// its bit pattern. std::memcpy is used instead of reinterpret_cast because
// reading a half object through a rocblas_half lvalue violates strict
// aliasing; the static_assert guards the assumption that both types have the
// same size.
rocblas_half to_rocblas_type(half x)
{
    static_assert(sizeof(rocblas_half) == sizeof(half),
                  "half and rocblas_half must have the same size");
    rocblas_half result{};
    std::memcpy(&result, &x, sizeof(result));
    return result;
}
Paul's avatar
Paul committed
90

wsttiger's avatar
wsttiger committed
91
92
93
94
95
// Output shape of the GPU gemm. Three inputs are required: the two GEMM
// operands and a third buffer (the one compute() writes into and returns).
// Only the two operands take part in shape inference, which is delegated
// to the wrapped operator.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(3);
    const std::vector<shape> operands{inputs.at(0), inputs.at(1)};
    return op.compute_shape(operands);
}
wsttiger's avatar
wsttiger committed
96
97
98
// Run the GEMM on the GPU through rocBLAS.
//
// args[0] and args[1] are the A and B operands. args[2] receives the result
// and is also returned. The routine computes alpha * A * B + beta * C with
// alpha = 1 and beta = 0, i.e. the plain product A * B. Every dimension
// before the last two is treated as a batch dimension.
// Throws if the inputs have fewer than two dimensions.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    const std::size_t n_dims = args[0].get_shape().lens().size();
    // The matrix indices below are n_dims - 2 and n_dims - 1; with fewer than
    // two dimensions they would wrap around and index out of bounds.
    if(n_dims < 2)
    {
        MIGRAPHX_THROW("MIOPEN_GEMM: inputs must have at least two dimensions");
    }
    const std::size_t dim_0 = n_dims - 2; // row index of each matrix
    const std::size_t dim_1 = n_dims - 1; // column index of each matrix

    const bool transa = args[0].get_shape().transposed();
    const bool transb = args[1].get_shape().transposed();
    // For a transposed operand the leading dimension is the stride of its
    // (logical) column index rather than of its row index.
    const rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    const rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    const rocblas_int ldc = args[2].get_shape().strides()[dim_0];

    const auto& out_lens = output_shape.lens();
    const rocblas_int m  = out_lens[dim_0];
    const rocblas_int n  = out_lens[dim_1];
    const rocblas_int k  = args[0].get_shape().lens()[dim_1];
    // Number of matrices: product of every output dimension before the last two.
    const std::size_t batch_num = std::accumulate(
        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

    // An empty output has nothing to compute. In particular a zero batch
    // dimension must not fall through to the single-GEMM path below, which
    // would write an m x n result into a zero-element output buffer.
    if(batch_num == 0 || m == 0 || n == 0)
    {
        return args[2];
    }

    const auto op_a = transa ? rocblas_operation_transpose : rocblas_operation_none;
    const auto op_b = transb ? rocblas_operation_transpose : rocblas_operation_none;
    // Distance between consecutive matrices of a batch. This assumes the
    // matrices of each argument are packed back to back.
    const rocblas_int stride_a = m * k;
    const rocblas_int stride_b = k * n;
    const rocblas_int stride_c = m * n;

    output_shape.visit_type([&](auto as) {
        auto alpha_r    = to_rocblas_type(as(alpha));
        auto beta_r     = to_rocblas_type(as(beta));
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

        // rocBLAS works on column-major matrices. The row-major product
        // C = A * B is the column-major product C^T = B^T * A^T, so B is
        // passed as the first operand and the m/n extents are swapped.
        // The strided-batched routine is only used when there is more than
        // one matrix.
        if(batch_num > 1)
        {
            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         op_b,
                                         op_a,
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         to_pointer(args[1]),
                                         ldb,
                                         stride_b,
                                         to_pointer(args[0]),
                                         lda,
                                         stride_a,
                                         &beta_r,
                                         to_pointer(args[2]),
                                         ldc,
                                         stride_c,
                                         batch_num);
        }
        else
        {
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 op_b,
                                 op_a,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 lda,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 ldc);
        }
    });

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
167
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
168
} // namespace migraphx