gemm.cpp 10.3 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
#include <cstring>
#include <functional>
#include <numeric>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Tag-dispatched front end for rocBLAS scal. The shape::as<T> tag selects
// the precision-specific routine; every other argument is forwarded to it
// untouched.

// Fallback: element types for which no rocBLAS scal routine is wired up.
template <class T, class... Ts>
void generic_rocblas_scal(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Single precision -> rocblas_sscal.
template <class... Args>
void generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

// Double precision -> rocblas_dscal.
template <class... Args>
void generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

// Tag-dispatched front end for rocBLAS axpy. The shape::as<T> tag picks the
// precision-specific routine; the remaining arguments are forwarded as-is.

// Fallback: element types for which no rocBLAS axpy routine is wired up.
template <class T, class... Ts>
void generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Half precision -> rocblas_haxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

// Single precision -> rocblas_saxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

// Double precision -> rocblas_daxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

// Tag-dispatched front end for rocBLAS dot. The shape::as<T> tag picks the
// precision-specific routine; the remaining arguments are forwarded as-is.

// Fallback: element types for which no rocBLAS dot routine is wired up
// (this includes half).
template <class T, class... Ts>
void generic_rocblas_dot(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Single precision -> rocblas_sdot.
template <class... Args>
void generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

// Double precision -> rocblas_ddot.
template <class... Args>
void generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

// Tag-dispatched front end for rocBLAS gemv. The shape::as<T> tag picks the
// precision-specific routine; the remaining arguments are forwarded as-is.

// Single precision -> rocblas_sgemv.
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    rocblas_sgemv(std::forward<Ts>(xs)...);
}

// Double precision -> rocblas_dgemv.
template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    rocblas_dgemv(std::forward<Ts>(xs)...);
}

// Fallback: element types for which no rocBLAS gemv routine is wired up
// (this includes half). The message names the routine family correctly
// (GEMV, matching the function name), so logs can be searched by it.
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Tag-dispatched front end for rocBLAS strided-batched gemm. The
// shape::as<T> tag picks the precision-specific routine; the remaining
// arguments are forwarded as-is.

// Fallback: element types for which no strided-batched gemm is wired up.
template <class T, class... Ts>
void generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

// Half precision -> rocblas_hgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Single precision -> rocblas_sgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// Double precision -> rocblas_dgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
template <class T>
Paul's avatar
Paul committed
135
136
137
138
139
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Converts a scalar value to its rocBLAS counterpart type rb_type<T>
// (identity for most types; half -> rocblas_half).
//
// The bytes are copied with std::memcpy instead of being read through a
// reinterpret_cast reference: reading an object through a reference to an
// unrelated type is a strict-aliasing violation (undefined behaviour)
// whenever rb_type<T> differs from T. The static_assert guarantees the copy
// is a bit-exact reinterpretation.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
    static_assert(sizeof(rb_type<T>) == sizeof(T),
                  "rocBLAS type must have the same size as the MIGraphX type");
    rb_type<T> result{};
    std::memcpy(&result, &x, sizeof(result));
    return result;
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Converts a half value to rocblas_half. The bytes are copied with
// std::memcpy rather than read through a reinterpret_cast reference, which
// would access the half object through an unrelated type (a strict-aliasing
// violation). The static_assert guarantees a bit-exact copy.
rocblas_half to_rocblas_type(half x)
{
    static_assert(sizeof(rocblas_half) == sizeof(half),
                  "rocblas_half and half must have the same size");
    rocblas_half result{};
    std::memcpy(&result, &x, sizeof(result));
    return result;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// The output shape is whatever the wrapped `op` member computes for these
// inputs; this GPU wrapper adds no shape logic of its own. (Blame
// annotations that were interleaved with this definition have been
// removed.)
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    return op.compute_shape(inputs);
}
173

Shucai Xiao's avatar
Shucai Xiao committed
174
175
176
std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
                                        std::size_t index,
                                        std::vector<std::size_t>& data_lens) const
177
178
179
{
}

wsttiger's avatar
wsttiger committed
180
181
182
// Computes alpha * A * B, plus beta * C when a bias C is supplied, with
// rocBLAS.
//
// `args` is {A, B, output} or {A, B, C, output}. The last element is the
// buffer the result is written into, and it is the argument returned.
// Three cases are dispatched:
//  * the output has a single element: a dot product of A and B, scaled by
//    alpha, then (with C) beta * C is added;
//  * B has rank 1: one gemv per matrix of the (possibly batched) A;
//  * otherwise: a single strided-batched gemm over all batches.
// Element types with no rocBLAS routine make the generic_rocblas_* helpers
// throw.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const bool is_3inputs  = (args.size() == 4);
    const argument& result = is_3inputs ? args[3] : args[2];
    auto handle            = ctx.get_stream().get_rocblas();

    // Typed rocBLAS pointer to element `offset` of `arg`. The offset is
    // applied after converting to the element type, so it counts elements
    // rather than units of whatever argument::data() points to.
    auto typed_ptr = [](auto as, const argument& arg, std::size_t offset = 0) {
        return to_rocblas_type(as.from(arg.data()) + offset);
    };

    // gemv/gemm below are called with beta and use the result buffer as
    // their accumulator input, so it must first hold C (or zeros when there
    // is no C). Both copies are queued on the context's stream rather than
    // the null stream. NOTE(review): this assumes it is the stream the
    // rocBLAS handle is bound to.
    auto seed_result = [&] {
        auto stream = ctx.get_stream().get();
        if(is_3inputs)
            hipMemcpyAsync(result.data(),
                           args[2].data(),
                           output_shape.bytes(),
                           hipMemcpyDeviceToDevice,
                           stream);
        else
            hipMemsetAsync(result.data(), 0, output_shape.bytes(), stream);
    };

    // Case 1: the output is a single scalar -> inner product.
    if(output_shape.elements() == 1)
    {
        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(op.beta));

            // The dot result is written straight into the device-resident
            // result buffer, which needs device pointer mode...
            rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device);
            generic_rocblas_dot(as,
                                handle,
                                args[1].get_shape().elements(),
                                typed_ptr(as, args[0]),
                                1,
                                typed_ptr(as, args[1]),
                                1,
                                typed_ptr(as, result));
            // ...while alpha/beta below (and in the other paths) are host
            // scalars.
            rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);

            // result = alpha * dot
            generic_rocblas_scal(as, handle, 1, &alpha_r, typed_ptr(as, result), 1);
            // result += beta * C
            if(is_3inputs)
            {
                generic_rocblas_axpy(as,
                                     handle,
                                     1,
                                     &beta_r,
                                     typed_ptr(as, args[2]),
                                     1,
                                     typed_ptr(as, result),
                                     1);
            }
        });
        return result;
    }

    // Case 2: B is a vector -> matrix * vector for every matrix in A. (A
    // vector-vector product already took the scalar path above.)
    if(args[1].get_shape().lens().size() == 1)
    {
        const auto& a_shape = args[0].get_shape();
        const auto& a_lens  = a_shape.lens();
        if(a_lens.size() < 2)
            MIGRAPHX_THROW("MIOPEN_GEMM: matrix operand must have at least two dimensions");

        const std::size_t dim_0  = a_lens.size() - 2;
        const std::size_t dim_1  = a_lens.size() - 1;
        const bool transa        = a_shape.transposed();
        const rocblas_int lda    = a_shape.strides()[transa ? dim_1 : dim_0];
        const rocblas_int m      = a_lens[dim_0];
        const rocblas_int k      = a_lens[dim_1];
        const std::size_t a_step = a_lens[dim_0] * a_lens[dim_1];
        const std::size_t y_step = a_lens[dim_0];
        const auto batch_num     = std::accumulate(
            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

        seed_result();
        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(op.beta));
            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
            {
                // rocBLAS is column-major. A row-major m x k matrix is a
                // column-major k x m matrix, so it is used transposed; a
                // transposed (column-major) A is used as-is.
                generic_rocblas_gemv(as,
                                     handle,
                                     transa ? rocblas_operation_none : rocblas_operation_transpose,
                                     transa ? m : k,
                                     transa ? k : m,
                                     &alpha_r,
                                     typed_ptr(as, args[0], batch_no * a_step),
                                     lda,
                                     typed_ptr(as, args[1]),
                                     1,
                                     &beta_r,
                                     typed_ptr(as, result, batch_no * y_step),
                                     1);
            }
        });
        return result;
    }

    // Case 3: general (batched) matrix * matrix.
    const auto& a_shape  = args[0].get_shape();
    const auto& b_shape  = args[1].get_shape();
    const auto& a_lens   = a_shape.lens();
    const auto& b_lens   = b_shape.lens();
    const auto& out_lens = output_shape.lens();
    if(a_lens.size() < 2 || b_lens.size() < 2 || out_lens.size() < 2)
        MIGRAPHX_THROW("MIOPEN_GEMM: matrix operands must have at least two dimensions");

    const std::size_t a_dims   = a_lens.size();
    const std::size_t b_dims   = b_lens.size();
    const std::size_t out_dims = out_lens.size();
    const bool transa          = a_shape.transposed();
    const bool transb          = b_shape.transposed();
    const rocblas_int lda      = a_shape.strides()[transa ? a_dims - 1 : a_dims - 2];
    const rocblas_int ldb      = b_shape.strides()[transb ? b_dims - 1 : b_dims - 2];
    const rocblas_int ldc      = output_shape.strides()[out_dims - 2];
    const rocblas_int m        = out_lens[out_dims - 2];
    const rocblas_int n        = out_lens[out_dims - 1];
    const rocblas_int k        = a_lens[a_dims - 1];
    const auto batch_num       = std::accumulate(
        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

    seed_result();
    output_shape.visit_type([&](auto as) {
        auto alpha_r = to_rocblas_type(as(op.alpha));
        auto beta_r  = to_rocblas_type(as(op.beta));
        // rocBLAS is column-major: the row-major product C = A * B is
        // computed as the column-major C^T = B^T * A^T. Hence B is passed
        // first and m/n are swapped. C is the result buffer (already seeded
        // with the bias), never the input C.
        generic_rocblas_batched_gemm(as,
                                     handle,
                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
                                     n,
                                     m,
                                     k,
                                     &alpha_r,
                                     typed_ptr(as, args[1]),
                                     ldb,
                                     k * n,
                                     typed_ptr(as, args[0]),
                                     lda,
                                     m * k,
                                     &beta_r,
                                     typed_ptr(as, result),
                                     ldc,
                                     m * n,
                                     static_cast<rocblas_int>(batch_num));
    });

    return result;
}

} // namespace gpu
Paul's avatar
Paul committed
325
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
326
} // namespace migraphx