"vscode:/vscode.git/clone" did not exist on "a09dc50215f91c5ec2c0d6b2afd9fa0ed45f931c"
gemm.cpp 17.1 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Type-dispatched wrappers around the rocBLAS "scal" routines.
// The leading shape::as<T> tag only selects the overload; every remaining
// argument is perfectly forwarded to the matching rocBLAS entry point.

// float -> rocblas_sscal
template <class... Args>
void generic_rocblas_scal(shape::as<float> /*tag*/, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

// double -> rocblas_dscal
template <class... Args>
void generic_rocblas_scal(shape::as<double> /*tag*/, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

// Fallback for every other element type: no rocBLAS call is made, an error is raised.
template <class T, class... Args>
void generic_rocblas_scal(shape::as<T> /*tag*/, Args&&... /*args*/)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Type-dispatched wrappers around the rocBLAS "axpy" routines.
// The leading shape::as<T> tag only selects the overload; every remaining
// argument is perfectly forwarded to the matching rocBLAS entry point.

// half -> rocblas_haxpy
template <class... Args>
void generic_rocblas_axpy(shape::as<half> /*tag*/, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

// float -> rocblas_saxpy
template <class... Args>
void generic_rocblas_axpy(shape::as<float> /*tag*/, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

// double -> rocblas_daxpy
template <class... Args>
void generic_rocblas_axpy(shape::as<double> /*tag*/, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

// Fallback for every other element type: no rocBLAS call is made, an error is raised.
template <class T, class... Args>
void generic_rocblas_axpy(shape::as<T> /*tag*/, Args&&... /*args*/)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Type-dispatched wrappers around the rocBLAS "dot" routines.
// The leading shape::as<T> tag only selects the overload; every remaining
// argument is perfectly forwarded to the matching rocBLAS entry point.

// float -> rocblas_sdot
template <class... Args>
void generic_rocblas_dot(shape::as<float> /*tag*/, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

// double -> rocblas_ddot
template <class... Args>
void generic_rocblas_dot(shape::as<double> /*tag*/, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

// Fallback for every other element type: no rocBLAS call is made, an error is raised.
template <class T, class... Args>
void generic_rocblas_dot(shape::as<T> /*tag*/, Args&&... /*args*/)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Type-dispatched wrappers around the rocBLAS "gemv" routines.
// The leading shape::as<T> tag only selects the overload; every remaining
// argument is perfectly forwarded to the matching rocBLAS entry point.

// float -> rocblas_sgemv
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    rocblas_sgemv(std::forward<Ts>(xs)...);
}

// double -> rocblas_dgemv
template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    rocblas_dgemv(std::forward<Ts>(xs)...);
}

// Fallback for every other element type: no rocBLAS call is made, an error is raised.
// The message prefix names this function (GEMV), in line with the sibling wrappers.
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Type-dispatched wrappers around the rocBLAS strided-batched GEMM routines.
// The leading shape::as<T> tag only selects the overload; every remaining
// argument is perfectly forwarded to the matching rocBLAS entry point.

// float -> rocblas_sgemm_strided_batched
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<float> /*tag*/, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// double -> rocblas_dgemm_strided_batched
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<double> /*tag*/, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// half -> rocblas_hgemm_strided_batched
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<half> /*tag*/, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Fallback for every other element type: no rocBLAS call is made, an error is raised.
template <class T, class... Args>
void generic_rocblas_batched_gemm(shape::as<T> /*tag*/, Args&&... /*args*/)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
// Maps an element type to the type handed to rocBLAS.
// Primary template: the type is used unchanged.
template <class T>
struct compute_rocblas_type
{
    using type = T;
};

// const-qualified types map to the const-qualified version of whatever the
// unqualified type maps to, so constness survives the translation.
template <class T>
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Converts a scalar to its rocBLAS-side type by reinterpreting the bits of the
// by-value copy `x` as rb_type<T> and returning that value.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
    const rb_type<T>& reinterpreted = reinterpret_cast<const rb_type<T>&>(x);
    return reinterpreted;
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Non-template overload for a half scalar: reinterprets the bits of the
// by-value copy `x` as rocblas_half and returns that value.
rocblas_half to_rocblas_type(half x)
{
    const rocblas_half& reinterpreted = reinterpret_cast<const rocblas_half&>(x);
    return reinterpreted;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Returns the output shape for the given input shapes by delegating to the
// wrapped operator `op`.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    return op.compute_shape(inputs);
}
173

Shucai Xiao's avatar
Shucai Xiao committed
174
void miopen_gemm::fill_result(const shape& output_shape,
Shucai Xiao's avatar
Shucai Xiao committed
175
176
                              const argument& result,
                              const argument& c) const
177
{
178
    auto out_lens = output_shape.lens();
Shucai Xiao's avatar
Shucai Xiao committed
179
180
    auto c_lens   = c.get_shape().lens();
    if(output_shape == c.get_shape())
181
182
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
183
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
184
185
            hipMemcpy(to_pointer(result),
                      to_pointer(c),
Shucai Xiao's avatar
Shucai Xiao committed
186
187
                      output_shape.bytes(),
                      hipMemcpyDeviceToDevice);
188
189
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
190
    else if(c.single())
191
192
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
193
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
194
195
196
197
198
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
Shucai Xiao's avatar
Shucai Xiao committed
199
200
201
                hipMemcpy(to_pointer(result, i),
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
202
                          hipMemcpyDeviceToDevice);
203
204
205
            }
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
206
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
207
208
209
210
    {
        auto m = out_lens[0];
        auto n = out_lens[1];
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
211
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
212
213
214
215
216
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < m; ++i)
            {
Shucai Xiao's avatar
Shucai Xiao committed
217
218
219
                hipMemcpy(to_pointer(result, i * n),
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
220
                          hipMemcpyDeviceToDevice);
221
222
223
224
225
226
227
228
229
230
231
232
233
            }
        });
    }
    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
    else
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
Shucai Xiao's avatar
Shucai Xiao committed
234
235
236
                hipMemcpy(to_pointer(result, i),
                          to_pointer(c, i / out_lens[0]),
                          c.get_shape().type_size(),
Shucai Xiao's avatar
Shucai Xiao committed
237
                          hipMemcpyDeviceToDevice);
238
239
240
            }
        });
    }
241
242
}

wsttiger's avatar
wsttiger committed
243
244
245
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
wsttiger's avatar
wsttiger committed
246
{
247
    bool is_3inputs = (args.size() == 4);
Shucai Xiao's avatar
Shucai Xiao committed
248
    if(is_3inputs)
249
    {
Shucai Xiao's avatar
Shucai Xiao committed
250
        fill_result(output_shape, args[3], args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
251

252
253
254
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(op.beta));
Shucai Xiao's avatar
Shucai Xiao committed
255
256
257
258
259
260
261
262
263
            bool transa     = args[0].get_shape().transposed();
            bool transb     = args[1].get_shape().transposed();
            rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
            rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
            rocblas_int ldc = args[2].get_shape().strides()[0];
            auto out_lens   = output_shape.lens();
            rocblas_int m   = out_lens[0];
            rocblas_int n   = out_lens[1];
            rocblas_int k   = args[0].get_shape().lens()[1];
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 lda,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 ldc);

        });

        return args[3];
    }
285

286
287
    // 2 input arguments cases
    // vector inner product
Shucai Xiao's avatar
Shucai Xiao committed
288
    if(output_shape.elements() == 1)
289
    {
290
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
291
292
293
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
294
295
296
            generic_rocblas_dot(as,
                                ctx.get_stream().get_rocblas(),
                                args[1].get_shape().elements(),
297
298
299
300
                                to_pointer(args[0]),
                                1,
                                to_pointer(args[1]),
                                1,
301
                                to_pointer(args[2]));
302

Shucai Xiao's avatar
Shucai Xiao committed
303
            generic_rocblas_scal(
Shucai Xiao's avatar
Shucai Xiao committed
304
                as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
305
306
        });
    }
307
    // matrix * vector
Shucai Xiao's avatar
Shucai Xiao committed
308
    else if(args[1].get_shape().lens().size() == 1)
309
    {
Shucai Xiao's avatar
Shucai Xiao committed
310
        auto a_lens       = args[0].get_shape().lens();
311
312
313
        std::size_t dim_0 = a_lens.size() - 2;
        std::size_t dim_1 = a_lens.size() - 1;
        bool trans        = args[0].get_shape().transposed();
Shucai Xiao's avatar
Shucai Xiao committed
314
315
316
317
        rocblas_int m     = a_lens[trans ? dim_1 : dim_0];
        rocblas_int n     = a_lens[trans ? dim_0 : dim_1];
        float beta        = 0.0f;
        rocblas_int lda   = args[0].get_shape().strides()[trans ? dim_1 : dim_0];
318
319

        assert(a_lens.back() == args[1].get_shape().elements());
Shucai Xiao's avatar
Shucai Xiao committed
320
321
        std::size_t batch_num = std::accumulate(
            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
322
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
323
            auto alpha_r = to_rocblas_type(as(op.alpha));
Shucai Xiao's avatar
Shucai Xiao committed
324
            auto beta_r = to_rocblas_type(as(beta));
Shucai Xiao's avatar
Shucai Xiao committed
325
326
327
328
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };
            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
329
            {
330
331
332
333
334
335
336
337
338
339
340
                generic_rocblas_gemv(as,
                                     ctx.get_stream().get_rocblas(),
                                     trans ? rocblas_operation_transpose : rocblas_operation_none,
                                     m,
                                     n,
                                     &alpha_r,
                                     to_pointer(args[0], batch_no * m * n),
                                     lda,
                                     to_pointer(args[1]),
                                     1,
                                     &beta_r,
Shucai Xiao's avatar
Shucai Xiao committed
341
                                     to_pointer(args[2], batch_no * n), 1);
342
343
344
            }
        });
    }
345
    // vector * matrix
Shucai Xiao's avatar
Shucai Xiao committed
346
    else if(args[0].get_shape().lens().size() == 1)
347
    {
Shucai Xiao's avatar
Shucai Xiao committed
348
        auto b_lens       = args[1].get_shape().lens();
349
350
351
        std::size_t dim_0 = b_lens.size() - 2;
        std::size_t dim_1 = b_lens.size() - 1;
        bool trans        = !args[1].get_shape().transposed();
Shucai Xiao's avatar
Shucai Xiao committed
352
353
354
355
        rocblas_int m     = b_lens[trans ? dim_1 : dim_0];
        rocblas_int n     = b_lens[trans ? dim_0 : dim_1];
        float beta        = 0.0f;
        rocblas_int lda   = args[1].get_shape().strides()[trans ? dim_1 : dim_0];
356
357

        assert(b_lens.back() == args[0].get_shape().elements());
Shucai Xiao's avatar
Shucai Xiao committed
358
359
        std::size_t batch_num = std::accumulate(
            b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
360
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
361
            auto alpha_r = to_rocblas_type(as(op.alpha));
Shucai Xiao's avatar
Shucai Xiao committed
362
            auto beta_r = to_rocblas_type(as(beta));
Shucai Xiao's avatar
Shucai Xiao committed
363
364
365
366
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };
            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
367
368
            {
                generic_rocblas_gemv(as,
Shucai Xiao's avatar
Shucai Xiao committed
369
                                     ctx.get_stream().get_rocblas(),
370
                                     trans ? rocblas_operation_transpose : rocblas_operation_none,
Shucai Xiao's avatar
Shucai Xiao committed
371
372
373
374
375
                                     n,
                                     m,
                                     &alpha_r,
                                     to_pointer(args[0]),
                                     lda,
376
377
                                     to_pointer(args[1], batch_no * m * n),
                                     1,
Shucai Xiao's avatar
Shucai Xiao committed
378
                                     &beta_r,
Shucai Xiao's avatar
Shucai Xiao committed
379
                                     to_pointer(args[2], batch_no * m), 1);
380
381
382
383
384
385
            }
        });
    }
    // (batch) matrix multiplication
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
386
387
388
389
        bool transa   = args[0].get_shape().transposed();
        bool transb   = args[1].get_shape().transposed();
        auto a_lens   = args[0].get_shape().lens();
        auto b_lens   = args[1].get_shape().lens();
390
391
        auto out_lens = output_shape.lens();

Shucai Xiao's avatar
Shucai Xiao committed
392
393
394
395
396
397
398
399
        rocblas_int lda =
            args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2];
        rocblas_int ldb =
            args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2];
        rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2];
        rocblas_int m   = out_lens[out_lens.size() - 2];
        rocblas_int n   = out_lens[out_lens.size() - 1];
        rocblas_int k   = args[0].get_shape().lens()[a_lens.size() - 1];
Shucai Xiao's avatar
Shucai Xiao committed
400
        float beta        = 0.0f;
401
402
        auto input_dims = std::min(a_lens.size(), b_lens.size());
        std::size_t axis{0};
Shucai Xiao's avatar
Shucai Xiao committed
403
        for(axis = 2; axis < input_dims; ++axis)
404
        {
Shucai Xiao's avatar
Shucai Xiao committed
405
            if(a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
406
407
408
409
410
411
            {
                break;
            }
        }

        // The number of matrices that can be computed in one call
Shucai Xiao's avatar
Shucai Xiao committed
412
        // batch_num > 1, we need to call the batch_gemm function,
413
        // otherwise, call the gemm function directly
Shucai Xiao's avatar
Shucai Xiao committed
414
415
416
417
418
        std::size_t num_matrices =
            std::accumulate(a_lens.rbegin() + 2,
                            (axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis),
                            std::size_t{1},
                            std::multiplies<std::size_t>());
419
420
        std::size_t a_len_diff = out_lens.size() - a_lens.size();
        std::size_t b_len_diff = out_lens.size() - b_lens.size();
Shucai Xiao's avatar
Shucai Xiao committed
421
422
423
424
425
426
        std::vector<std::size_t> a_batch_lens(a_lens.begin(),
                                              a_lens.begin() + a_lens.size() - axis);
        std::vector<std::size_t> b_batch_lens(b_lens.begin(),
                                              b_lens.begin() + b_lens.size() - axis);
        std::vector<std::size_t> out_batch_lens(out_lens.begin(),
                                                out_lens.begin() + out_lens.size() - axis);
427
428
429
430

        shape::type_t t = output_shape.type();
        shape a_batch_shape{t, a_batch_lens};
        shape b_batch_shape{t, b_batch_lens};
Shucai Xiao's avatar
Shucai Xiao committed
431
        shape out_batch_shape{t, out_batch_lens};
432

Shucai Xiao's avatar
Shucai Xiao committed
433
        shape_for_each(out_batch_shape, [&](auto out_idx) {
434
435
436
            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
            std::vector<std::size_t> a_idx(a_lens.size() - axis);
            std::vector<std::size_t> b_idx(b_lens.size() - axis);
Shucai Xiao's avatar
Shucai Xiao committed
437
438
439
440
441
442
443
444
445
446
            std::transform(out_idx.begin() + a_len_diff,
                           out_idx.end(),
                           a_batch_lens.begin(),
                           a_idx.begin(),
                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
            std::transform(out_idx.begin() + b_len_diff,
                           out_idx.end(),
                           b_batch_lens.begin(),
                           b_idx.begin(),
                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
447
448
449
450
451

            std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end());
            std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());

            output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
452
                auto alpha_r = to_rocblas_type(as(op.alpha));
Shucai Xiao's avatar
Shucai Xiao committed
453
                auto beta_r = to_rocblas_type(as(beta));
Shucai Xiao's avatar
Shucai Xiao committed
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
                auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                    return to_rocblas_type(as.from(arg.data() + offset));
                };
                generic_rocblas_batched_gemm(
                    as,
                    ctx.get_stream().get_rocblas(),
                    transb ? rocblas_operation_transpose : rocblas_operation_none,
                    transa ? rocblas_operation_transpose : rocblas_operation_none,
                    n,
                    m,
                    k,
                    &alpha_r,
                    to_pointer(args[1], k * n * num_matrices * b_ind),
                    ldb,
                    k * n,
                    to_pointer(args[0], m * k * num_matrices * a_ind),
                    lda,
                    m * k,
                    &beta_r,
                    to_pointer(args[2], m * n * num_matrices * out_ind),
                    ldc,
                    m * n,
                    num_matrices);
477
478
479
            });
        });
    }
480

481
    return args[2];
wsttiger's avatar
wsttiger committed
482
483
484
}

} // namespace gpu
Paul's avatar
Paul committed
485
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
486
} // namespace migraphx