gemm.cpp 17.2 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Tag-dispatched wrappers for the rocBLAS SCAL routine. The shape::as<T> tag selects the
// precision-specific entry point; every other argument is forwarded to it untouched.

// float -> rocblas_sscal
template <typename... Args>
void generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

// double -> rocblas_dscal
template <typename... Args>
void generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

// Any other element type: no SCAL routine is wired up, so fail at run time.
template <typename T, typename... Args>
void generic_rocblas_scal(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Tag-dispatched wrappers for the rocBLAS AXPY routine. The shape::as<T> tag selects the
// precision-specific entry point; every other argument is forwarded to it untouched.

// half -> rocblas_haxpy
template <typename... Args>
void generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

// float -> rocblas_saxpy
template <typename... Args>
void generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

// double -> rocblas_daxpy
template <typename... Args>
void generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

// Any other element type: no AXPY routine is wired up, so fail at run time.
template <typename T, typename... Args>
void generic_rocblas_axpy(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Tag-dispatched wrappers for the rocBLAS DOT routine. The shape::as<T> tag selects the
// precision-specific entry point; every other argument is forwarded to it untouched.

// float -> rocblas_sdot
template <typename... Args>
void generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

// double -> rocblas_ddot
template <typename... Args>
void generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

// Any other element type: no DOT routine is wired up, so fail at run time.
template <typename T, typename... Args>
void generic_rocblas_dot(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Tag-dispatched wrappers for the rocBLAS GEMV routine. The shape::as<T> tag selects the
// precision-specific entry point; every other argument is forwarded to it untouched.

// float -> rocblas_sgemv
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    rocblas_sgemv(std::forward<Ts>(xs)...);
}

// double -> rocblas_dgemv
template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    rocblas_dgemv(std::forward<Ts>(xs)...);
}

// Any other element type: no GEMV routine is wired up, so fail at run time.
// (The message previously named the routine "GEMMV"; it now matches the function name.)
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Tag-dispatched wrappers for the rocBLAS strided-batched GEMM routines. The shape::as<T>
// tag selects the precision-specific entry point; every other argument is forwarded untouched.

// float -> rocblas_sgemm_strided_batched
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// double -> rocblas_dgemm_strided_batched
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// half -> rocblas_hgemm_strided_batched
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Any other element type: no batched GEMM routine is wired up, so fail at run time.
template <typename T, typename... Args>
void generic_rocblas_batched_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
template <class T>
Paul's avatar
Paul committed
135
136
137
138
139
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

template <class T>
rb_type<T> to_rocblas_type(T x)
{
    return reinterpret_cast<const rb_type<T>&>(x);
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Output shape of the GPU GEMM: fully delegated to the wrapped operator `op`.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    shape out_shape = op.compute_shape(inputs);
    return out_shape;
}
173

Shucai Xiao's avatar
Shucai Xiao committed
174
175
176
177
void miopen_gemm::fill_result(context& ctx,
                              const shape& output_shape,
                              const argument& result,
                              const argument& c) const
178
{
179
    auto out_lens = output_shape.lens();
Shucai Xiao's avatar
Shucai Xiao committed
180
181
    auto c_lens   = c.get_shape().lens();
    if(output_shape == c.get_shape())
182
183
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
184
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
185
            hipMemcpy(to_pointer(args[3]),
Shucai Xiao's avatar
Shucai Xiao committed
186
187
188
                      to_pointer(args[2]),
                      output_shape.bytes(),
                      hipMemcpyDeviceToDevice);
189
190
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
191
    else if(c.single())
192
193
194
195
196
197
198
199
200
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
                hipMemcpy(to_pointer(args[3], i),
Shucai Xiao's avatar
Shucai Xiao committed
201
202
203
                          to_pointer(args[2]),
                          args[2].get_shape().bytes(),
                          hipMemcpyDeviceToDevice);
204
205
206
            }
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
207
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
208
209
210
211
212
213
214
215
216
217
218
    {
        auto m = out_lens[0];
        auto n = out_lens[1];
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < m; ++i)
            {
                hipMemcpy(to_pointer(args[3], i * n),
Shucai Xiao's avatar
Shucai Xiao committed
219
220
221
                          to_pointer(args[2]),
                          args[2].get_shape().bytes(),
                          hipMemcpyDeviceToDevice);
222
223
224
225
226
227
228
229
230
231
232
233
234
235
            }
        });
    }
    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
    else
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
                hipMemcpy(to_pointer(args[3], i),
Shucai Xiao's avatar
Shucai Xiao committed
236
237
238
                          to_pointer(args[2], i / n),
                          args[2].get_shape().type_size(),
                          hipMemcpyDeviceToDevice);
239
240
241
            }
        });
    }
242
243
}

wsttiger's avatar
wsttiger committed
244
245
246
// Runs the GEMM described by `op` on the GPU through rocBLAS.
//
// `args` is {A, B, output} for a plain product, or {A, B, C, output} when a C operand is
// present. rocBLAS works on column-major matrices. A row-major matrix reinterpreted as
// column-major is its own transpose, so the GEMM calls compute C^T = B^T * A^T: B is passed
// before A and m/n are swapped.
//
// alpha/beta are passed as addresses of host locals. That requires the rocBLAS handle to be in
// host pointer mode (the rocBLAS default). NOTE(review): confirm the context never switches it.
//
// Returns the argument holding the result: args[3] with a C operand, args[2] otherwise.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    // Typed rocBLAS pointer into `arg`'s buffer for element type `as`. The offset is applied
    // after the typed cast, so it counts elements rather than bytes.
    const auto to_pointer = [](auto as, const argument& arg, std::size_t offset = 0) {
        return to_rocblas_type(as.from(arg.data()) + offset);
    };

    // ---- alpha * A * B + beta * C (2-D) -------------------------------------------------
    if(args.size() == 4)
    {
        // Seed the output with (broadcast) C; rocBLAS then accumulates into the output buffer.
        fill_result(ctx, output_shape, args[3], args[2]);

        const bool transa = args[0].get_shape().transposed();
        const bool transb = args[1].get_shape().transposed();
        const auto lda =
            static_cast<rocblas_int>(args[0].get_shape().strides()[transa ? 1 : 0]);
        const auto ldb =
            static_cast<rocblas_int>(args[1].get_shape().strides()[transb ? 1 : 0]);
        // The product is written to the output (args[3]), so its leading dimension comes from
        // the output layout, not from the possibly broadcast C operand.
        const auto ldc      = static_cast<rocblas_int>(output_shape.strides()[0]);
        const auto out_lens = output_shape.lens();
        const auto m        = static_cast<rocblas_int>(out_lens[0]);
        const auto n        = static_cast<rocblas_int>(out_lens[1]);
        const auto k        = static_cast<rocblas_int>(args[0].get_shape().lens()[1]);

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(op.beta));
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(as, args[1]),
                                 ldb,
                                 to_pointer(as, args[0]),
                                 lda,
                                 &beta_r,
                                 to_pointer(as, args[3]),
                                 ldc);
        });

        return args[3];
    }

    // ---- 2-input cases: output = alpha * (A * B), i.e. beta == 0 -------------------------
    const float beta = 0.0f;

    // Batched GEMV used by the matrix*vector and vector*matrix cases. For every matrix M in
    // `mat` (all leading dims are batch dims, matrices assumed densely packed), writes
    // op(M) * vec into consecutive slices of args[2]. op(M) is M^T when `transpose_mat` is set
    // and M otherwise. `vec` is a single 1-D vector shared by all batches.
    const auto batched_gemv = [&](const argument& mat, const argument& vec, bool transpose_mat) {
        const auto mat_lens          = mat.get_shape().lens();
        const auto mat_strides       = mat.get_shape().strides();
        const bool stored_transposed = mat.get_shape().transposed();
        const std::size_t dim_0      = mat_lens.size() - 2;
        const std::size_t dim_1      = mat_lens.size() - 1;
        const std::size_t rows       = mat_lens[dim_0];
        const std::size_t cols       = mat_lens[dim_1];
        assert((transpose_mat ? rows : cols) == vec.get_shape().elements());

        // Seen as column-major, row-major storage is the cols x rows matrix M^T, while
        // transposed storage already is the rows x cols matrix M.
        const auto cm_rows = static_cast<rocblas_int>(stored_transposed ? rows : cols);
        const auto cm_cols = static_cast<rocblas_int>(stored_transposed ? cols : rows);
        const auto ld = static_cast<rocblas_int>(mat_strides[stored_transposed ? dim_1 : dim_0]);
        // rocBLAS must transpose its column-major view exactly when that view is not op(M):
        // row-major storage gives M^T (transpose iff M is wanted), transposed storage gives M
        // (transpose iff M^T is wanted).
        const bool rocblas_transpose = (stored_transposed == transpose_mat);

        const std::size_t out_len     = transpose_mat ? cols : rows;
        const std::size_t matrix_size = rows * cols;
        const std::size_t batch_num   = std::accumulate(mat_lens.rbegin() + 2,
                                                      mat_lens.rend(),
                                                      std::size_t{1},
                                                      std::multiplies<std::size_t>());

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            for(std::size_t batch = 0; batch < batch_num; ++batch)
            {
                generic_rocblas_gemv(as,
                                     ctx.get_stream().get_rocblas(),
                                     rocblas_transpose ? rocblas_operation_transpose
                                                       : rocblas_operation_none,
                                     cm_rows,
                                     cm_cols,
                                     &alpha_r,
                                     to_pointer(as, mat, batch * matrix_size),
                                     ld,
                                     to_pointer(as, vec),
                                     1,
                                     &beta_r,
                                     to_pointer(as, args[2], batch * out_len),
                                     1);
            }
        });
    };

    if(output_shape.elements() == 1)
    {
        // Vector inner product, computed as a 1x1 GEMM: out = alpha * sum_i a[i] * b[i].
        // Unlike the former DOT + SCAL pair, GEMM always writes its result to device memory
        // (pointer mode only affects alpha/beta), folds alpha in, and also supports half.
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
        const auto k = static_cast<rocblas_int>(args[0].get_shape().elements());
        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 rocblas_operation_none,
                                 rocblas_operation_none,
                                 1,
                                 1,
                                 k,
                                 &alpha_r,
                                 to_pointer(as, args[1]),
                                 1,
                                 to_pointer(as, args[0]),
                                 std::max<rocblas_int>(k, 1),
                                 &beta_r,
                                 to_pointer(as, args[2]),
                                 1);
        });
    }
    else if(args[1].get_shape().lens().size() == 1)
    {
        // (batched) matrix * vector: out = A * x
        batched_gemv(args[0], args[1], false);
    }
    else if(args[0].get_shape().lens().size() == 1)
    {
        // vector * (batched) matrix: out = x * B = B^T * x
        batched_gemv(args[1], args[0], true);
    }
    else
    {
        // (batched) matrix * matrix with numpy-style broadcasting of the batch dims.
        const auto& a_shape = args[0].get_shape();
        const auto& b_shape = args[1].get_shape();
        const bool transa   = a_shape.transposed();
        const bool transb   = b_shape.transposed();
        const auto a_lens   = a_shape.lens();
        const auto b_lens   = b_shape.lens();
        const auto out_lens = output_shape.lens();

        const auto lda = static_cast<rocblas_int>(
            a_shape.strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2]);
        const auto ldb = static_cast<rocblas_int>(
            b_shape.strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2]);
        const auto ldc =
            static_cast<rocblas_int>(args[2].get_shape().strides()[out_lens.size() - 2]);
        const auto m = static_cast<rocblas_int>(out_lens[out_lens.size() - 2]);
        const auto n = static_cast<rocblas_int>(out_lens.back());
        const auto k = static_cast<rocblas_int>(a_lens.back());

        // `axis` = number of trailing dims handled by one strided-batched call: the two matrix
        // dims plus every further batch dim (walking outwards) that is identical in A and B.
        // The dim about to be included (index size - axis - 1) is the one that is checked;
        // stop at the first mismatch, since broadcasting cannot be expressed as a fixed stride.
        const std::size_t input_dims = std::min(a_lens.size(), b_lens.size());
        std::size_t axis             = 2;
        while(axis < input_dims &&
              a_lens[a_lens.size() - axis - 1] == b_lens[b_lens.size() - axis - 1])
        {
            ++axis;
        }

        // Number of matrices computed by one strided-batched call.
        const std::size_t num_matrices = std::accumulate(a_lens.rbegin() + 2,
                                                         a_lens.rbegin() + axis,
                                                         std::size_t{1},
                                                         std::multiplies<std::size_t>());

        // Element distance between consecutive matrices of each operand (dense packing assumed).
        const std::size_t stride_a = static_cast<std::size_t>(m) * k;
        const std::size_t stride_b = static_cast<std::size_t>(k) * n;
        const std::size_t stride_c = static_cast<std::size_t>(m) * n;

        // Outer batch dims not covered by one call; iterated explicitly with broadcasting.
        const std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.end() - axis);
        const std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.end() - axis);
        const std::vector<std::size_t> out_batch_lens(out_lens.begin(), out_lens.end() - axis);
        const std::size_t a_len_diff = out_batch_lens.size() - a_batch_lens.size();
        const std::size_t b_len_diff = out_batch_lens.size() - b_batch_lens.size();

        // Row-major linear index, within `batch_lens`, of the batch selected by the output batch
        // index `out_idx` (right-aligned by `len_diff`); dims of size 1 broadcast to index 0.
        const auto batch_index =
            [](const auto& out_idx, std::size_t len_diff, const std::vector<std::size_t>& batch_lens) {
                std::size_t index = 0;
                auto it           = out_idx.begin() + len_diff;
                for(const std::size_t len : batch_lens)
                {
                    const std::size_t i = (len == 1) ? std::size_t{0} : std::size_t(*it);
                    index               = index * len + i;
                    ++it;
                }
                return index;
            };

        // One strided-batched GEMM over `num_matrices` matrices starting at the given batches.
        const auto run_batched_gemm = [&](std::size_t a_ind, std::size_t b_ind, std::size_t out_ind) {
            output_shape.visit_type([&](auto as) {
                auto alpha_r = to_rocblas_type(as(op.alpha));
                auto beta_r  = to_rocblas_type(as(beta));
                generic_rocblas_batched_gemm(
                    as,
                    ctx.get_stream().get_rocblas(),
                    transb ? rocblas_operation_transpose : rocblas_operation_none,
                    transa ? rocblas_operation_transpose : rocblas_operation_none,
                    n,
                    m,
                    k,
                    &alpha_r,
                    to_pointer(as, args[1], stride_b * num_matrices * b_ind),
                    ldb,
                    k * n,
                    to_pointer(as, args[0], stride_a * num_matrices * a_ind),
                    lda,
                    m * k,
                    &beta_r,
                    to_pointer(as, args[2], stride_c * num_matrices * out_ind),
                    ldc,
                    m * n,
                    static_cast<rocblas_int>(num_matrices));
            });
        };

        if(out_batch_lens.empty())
        {
            // Every batch dim is covered by a single strided-batched call.
            run_batched_gemm(0, 0, 0);
        }
        else
        {
            const shape out_batch_shape{output_shape.type(), out_batch_lens};
            shape_for_each(out_batch_shape, [&](const auto& out_idx) {
                run_batched_gemm(batch_index(out_idx, a_len_diff, a_batch_lens),
                                 batch_index(out_idx, b_len_diff, b_batch_lens),
                                 batch_index(out_idx, 0, out_batch_lens));
            });
        }
    }

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
486
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
487
} // namespace migraphx