gemm.cpp 18.8 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Single-precision scal: every argument after the type tag is forwarded
// unchanged to rocblas_sscal.
template <class... Args>
void generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

// Double-precision scal: every argument after the type tag is forwarded
// unchanged to rocblas_dscal.
template <class... Args>
void generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

// Catch-all scal overload, chosen when no type-specific overload matches the
// tag. The remaining arguments are ignored and an error is raised.
template <class Elem, class... Args>
void generic_rocblas_scal(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Half-precision axpy: every argument after the type tag is forwarded
// unchanged to rocblas_haxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

// Single-precision axpy: every argument after the type tag is forwarded
// unchanged to rocblas_saxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

// Double-precision axpy: every argument after the type tag is forwarded
// unchanged to rocblas_daxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

// Catch-all axpy overload, chosen when no type-specific overload matches the
// tag. The remaining arguments are ignored and an error is raised.
template <class Elem, class... Args>
void generic_rocblas_axpy(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Single-precision dot product: every argument after the type tag is
// forwarded unchanged to rocblas_sdot.
template <class... Args>
void generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

// Double-precision dot product: every argument after the type tag is
// forwarded unchanged to rocblas_ddot.
template <class... Args>
void generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

// Catch-all dot overload, chosen when no type-specific overload matches the
// tag. The remaining arguments are ignored and an error is raised.
template <class Elem, class... Args>
void generic_rocblas_dot(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Single-precision gemv: every argument after the type tag is forwarded
// unchanged to rocblas_sgemv.
template <class... Args>
void generic_rocblas_gemv(shape::as<float>, Args&&... args)
{
    rocblas_sgemv(std::forward<Args>(args)...);
}

// Double-precision gemv: every argument after the type tag is forwarded
// unchanged to rocblas_dgemv.
template <class... Args>
void generic_rocblas_gemv(shape::as<double>, Args&&... args)
{
    rocblas_dgemv(std::forward<Args>(args)...);
}

// Catch-all gemv overload, chosen when no type-specific overload matches the
// tag. The remaining arguments are ignored and an error is raised.
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    // The tag in the message mirrors the function name, like the sibling
    // generic_rocblas_* fallbacks (it previously read "GEMMV").
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Single-precision strided-batched gemm: every argument after the type tag is
// forwarded unchanged to rocblas_sgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// Double-precision strided-batched gemm: every argument after the type tag is
// forwarded unchanged to rocblas_dgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// Half-precision strided-batched gemm: every argument after the type tag is
// forwarded unchanged to rocblas_hgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Catch-all strided-batched gemm overload, chosen when no type-specific
// overload matches the tag. The remaining arguments are ignored and an error
// is raised.
template <class Elem, class... Args>
void generic_rocblas_batched_gemm(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
template <class T>
Paul's avatar
Paul committed
135
136
137
138
139
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Returns a copy of `x` reinterpreted as the corresponding rocBLAS type
// (rb_type<Scalar>); the object representation is reused as-is.
template <class Scalar>
rb_type<Scalar> to_rocblas_type(Scalar x)
{
    using rocblas_scalar         = rb_type<Scalar>;
    const rocblas_scalar& viewed = reinterpret_cast<const rocblas_scalar&>(x);
    return viewed;
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Non-template overload for `half`: returns the value reinterpreted as
// rocblas_half.
rocblas_half to_rocblas_type(half x)
{
    const rocblas_half& viewed = reinterpret_cast<const rocblas_half&>(x);
    return viewed;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Computes the output shape by delegating to the wrapped operator with the
// last entry of `inputs` removed.
// NOTE(review): the dropped entry is presumably the shape of the preallocated
// output buffer (compute() returns its last argument) -- confirm.
//
// Throws (via MIGRAPHX_THROW) when `inputs` is empty; previously that case
// formed an iterator before begin(), which is undefined behaviour.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    if(inputs.empty())
    {
        MIGRAPHX_THROW("MIOPEN_GEMM: compute_shape called without any input shapes");
    }

    std::vector<shape> orig_inputs(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(orig_inputs);
}
174

Shucai Xiao's avatar
Shucai Xiao committed
175
void miopen_gemm::fill_result(const shape& output_shape,
Shucai Xiao's avatar
Shucai Xiao committed
176
177
                              const argument& result,
                              const argument& c) const
178
{
Shucai Xiao's avatar
Shucai Xiao committed
179
180
    auto out_lens  = output_shape.lens();
    auto c_lens    = c.get_shape().lens();
181
    auto type_size = output_shape.type_size();
Shucai Xiao's avatar
Shucai Xiao committed
182
    if(output_shape == c.get_shape())
183
184
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
185
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
186
187
            hipMemcpy(
                to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
188
189
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
190
    else if(c.single())
191
192
    {
        output_shape.visit_type([&](auto as) {
193
194
            auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
                return to_rocblas_type(as.from(arg.data() + offset_byte));
195
196
197
198
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
199
                hipMemcpy(to_pointer(result, i * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
200
201
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
202
                          hipMemcpyDeviceToDevice);
203
204
205
            }
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
206
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
207
208
209
210
    {
        auto m = out_lens[0];
        auto n = out_lens[1];
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
211
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
212
213
214
215
216
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < m; ++i)
            {
217
                hipMemcpy(to_pointer(result, i * n * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
218
219
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
220
                          hipMemcpyDeviceToDevice);
221
222
223
224
225
226
227
228
229
230
231
232
233
            }
        });
    }
    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
    else
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
234
235
236
                hipMemcpy(to_pointer(result, i * type_size),
                          to_pointer(c, i / out_lens[1] * type_size),
                          type_size,
Shucai Xiao's avatar
Shucai Xiao committed
237
                          hipMemcpyDeviceToDevice);
238
239
240
            }
        });
    }
241
242
}

243
244
245
246
// Multiplies the (possibly batched) matrices held in args[0] and args[1] with
// rocBLAS, using op.alpha as the scale factor and a beta of zero, and writes
// the product into args[2], which is also the return value.
//
// The operands are handed to rocBLAS in swapped order (args[1] before args[0],
// n before m).
// NOTE(review): presumably this compensates for rocBLAS expecting
// column-major storage -- confirm.
//
// When both operands have identical batch dimensions, or either has none, a
// single strided-batched gemm is issued. Otherwise one gemm is issued per
// output batch index; operand batch dimensions of length 1 are pinned to
// index 0 so they are reused across that axis.
argument miopen_gemm::batch_matmul(context& ctx,
                                   const shape& output_shape,
                                   const std::vector<argument>& args) const
{
    const bool a_transposed = args[0].get_shape().transposed();
    const bool b_transposed = args[1].get_shape().transposed();
    const auto a_op = a_transposed ? rocblas_operation_transpose : rocblas_operation_none;
    const auto b_op = b_transposed ? rocblas_operation_transpose : rocblas_operation_none;

    const auto a_dims   = args[0].get_shape().lens();
    const auto b_dims   = args[1].get_shape().lens();
    const auto out_dims = output_shape.lens();

    const auto a_rank   = a_dims.size();
    const auto b_rank   = b_dims.size();
    const auto out_rank = out_dims.size();

    rocblas_int m   = out_dims[out_rank - 2];
    rocblas_int n   = out_dims[out_rank - 1];
    rocblas_int k   = a_dims[a_rank - 1];
    rocblas_int lda = args[0].get_shape().strides()[a_transposed ? a_rank - 1 : a_rank - 2];
    rocblas_int ldb = args[1].get_shape().strides()[b_transposed ? b_rank - 1 : b_rank - 2];
    rocblas_int ldc = args[2].get_shape().strides()[out_rank - 2];
    float beta      = 0.0f;

    // Everything in front of the trailing two (matrix) dimensions is batch.
    std::vector<std::size_t> a_batch(a_dims.begin(), a_dims.end() - 2);
    std::vector<std::size_t> b_batch(b_dims.begin(), b_dims.end() - 2);

    const bool single_call = a_batch == b_batch || a_batch.empty() || b_batch.empty();
    if(single_call)
    {
        auto product_of = [](const std::vector<std::size_t>& dims) {
            return std::accumulate(
                dims.begin(), dims.end(), std::size_t{1}, std::multiplies<std::size_t>());
        };
        const std::size_t a_count = product_of(a_batch);
        const std::size_t b_count = product_of(b_batch);
        std::size_t num_matrices  = std::max(a_count, b_count);

        // An operand made of a single matrix gets stride 0 so that every
        // batch entry reads the same matrix.
        rocblas_int stride_a = (a_count != 1) ? m * k : 0;
        rocblas_int stride_b = (b_count != 1) ? k * n : 0;
        rocblas_int stride_c = m * n;

        output_shape.visit_type([&](auto as) {
            auto alpha_r     = to_rocblas_type(as(op.alpha));
            auto beta_r      = to_rocblas_type(as(beta));
            auto typed_ptr   = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            auto a_ptr       = typed_ptr(args[0]);
            auto b_ptr       = typed_ptr(args[1]);
            auto out_ptr     = typed_ptr(args[2]);
            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         b_op,
                                         a_op,
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         b_ptr,
                                         ldb,
                                         stride_b,
                                         a_ptr,
                                         lda,
                                         stride_a,
                                         &beta_r,
                                         out_ptr,
                                         ldc,
                                         stride_c,
                                         num_matrices);
        });

        return args[2];
    }

    // General case: the batch dimensions differ, so each output batch entry
    // is computed by its own gemm call.
    std::vector<std::size_t> out_batch(out_dims.begin(), out_dims.end() - 2);
    shape::type_t elem_type = output_shape.type();
    shape a_batch_shape{elem_type, a_batch};
    shape b_batch_shape{elem_type, b_batch};
    shape out_batch_shape{elem_type, out_batch};
    std::size_t a_rank_gap = out_rank - a_rank;
    std::size_t b_rank_gap = out_rank - b_rank;
    auto type_size         = output_shape.type_size();

    // A batch dimension of length 1 always maps to index 0.
    auto pin_unit_dims = [](auto idx, auto len) { return (len == 1) ? 0 : idx; };

    shape_for_each(out_batch_shape, [&](auto out_idx) {
        std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());

        std::vector<std::size_t> a_idx(a_batch.size());
        std::vector<std::size_t> b_idx(b_batch.size());
        std::transform(out_idx.begin() + a_rank_gap,
                       out_idx.end(),
                       a_batch.begin(),
                       a_idx.begin(),
                       pin_unit_dims);
        std::transform(out_idx.begin() + b_rank_gap,
                       out_idx.end(),
                       b_batch.begin(),
                       b_idx.begin(),
                       pin_unit_dims);

        std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end());
        std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());

        // Byte offsets of the selected matrices inside each buffer.
        const std::size_t a_offset   = m * k * a_ind * type_size;
        const std::size_t b_offset   = k * n * b_ind * type_size;
        const std::size_t out_offset = m * n * out_ind * type_size;

        output_shape.visit_type([&](auto as) {
            auto alpha_r   = to_rocblas_type(as(op.alpha));
            auto beta_r    = to_rocblas_type(as(beta));
            auto typed_ptr = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };
            auto a_ptr   = typed_ptr(args[0], a_offset);
            auto b_ptr   = typed_ptr(args[1], b_offset);
            auto out_ptr = typed_ptr(args[2], out_offset);
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 b_op,
                                 a_op,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 b_ptr,
                                 ldb,
                                 a_ptr,
                                 lda,
                                 &beta_r,
                                 out_ptr,
                                 ldc);
        });
    });

    return args[2];
}

wsttiger's avatar
wsttiger committed
365
366
367
// Runs the gemm on the GPU through rocBLAS and returns the argument that holds
// the result (always the last element of `args`).
//
// Four arguments (A, B, C, result): `result` is first initialised from C via
// fill_result, then a strided-batched gemm is issued on it with op.alpha and
// op.beta.
//
// Three arguments (A, B, result), dispatched on the operand shapes:
//   * output has a single element  -> dot product followed by a scal by
//                                     op.alpha
//   * B is 1-D (matrix * vector)   -> strided-batched gemm with n == 1 and a
//                                     B stride of 0
//   * A is 1-D (vector * matrix)   -> strided-batched gemm with m == 1 and an
//                                     A stride of 0
//   * otherwise                    -> batch_matmul
//
// In every gemm call the operands are passed to rocBLAS in swapped order
// (B before A, n before m).
// NOTE(review): presumably this compensates for rocBLAS expecting
// column-major storage -- confirm.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    bool is_3inputs = (args.size() == 4);
    if(is_3inputs)
    {
        fill_result(output_shape, args[3], args[2]);
        output_shape.visit_type([&](auto as) {
            auto n_dim        = output_shape.lens().size();
            auto dim_1        = n_dim - 1;
            auto dim_0        = n_dim - 2;
            auto alpha_r      = to_rocblas_type(as(op.alpha));
            auto beta_r       = to_rocblas_type(as(op.beta));
            bool transa       = args[0].get_shape().transposed();
            bool transb       = args[1].get_shape().transposed();
            rocblas_int lda   = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
            rocblas_int ldb   = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
            rocblas_int ldc   = args[3].get_shape().strides()[dim_0];
            auto out_lens     = output_shape.lens();
            rocblas_int m     = out_lens[dim_0];
            rocblas_int n     = out_lens[dim_1];
            rocblas_int k     = args[0].get_shape().lens()[dim_1];
            // Every dimension ahead of the trailing two counts as batch.
            auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
                                                out_lens.rend(),
                                                std::size_t{1},
                                                std::multiplies<std::size_t>());
            auto to_pointer   = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                to_pointer(args[3]),
                ldc,
                m * n,
                num_matrices);
        });

        return args[3];
    }

    // 2 input arguments cases
    // vector inner product
    if(output_shape.elements() == 1)
    {
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            // NOTE(review): the dot result is written through args[2]'s
            // buffer while the scal factor below is a host-side local;
            // confirm both match the rocBLAS pointer mode configured on this
            // handle.
            generic_rocblas_dot(as,
                                ctx.get_stream().get_rocblas(),
                                args[1].get_shape().elements(),
                                to_pointer(args[0]),
                                1,
                                to_pointer(args[1]),
                                1,
                                to_pointer(args[2]));

            // Apply op.alpha to the single-element result.
            generic_rocblas_scal(
                as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
        });
    }
    // matrix * vector
    else if(args[1].get_shape().lens().size() == 1)
    {
        auto a_lens       = args[0].get_shape().lens();
        std::size_t dim_0 = a_lens.size() - 2;
        std::size_t dim_1 = a_lens.size() - 1;
        bool transa       = args[0].get_shape().transposed();
        bool transb       = false;
        rocblas_int lda   = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
        rocblas_int ldb   = 1;
        rocblas_int ldc   = 1;
        rocblas_int m     = a_lens[dim_0];
        rocblas_int n     = 1;
        rocblas_int k     = a_lens[dim_1];
        float beta        = 0.0f;
        assert(a_lens.back() == args[1].get_shape().elements());

        std::size_t batch_num = std::accumulate(
            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

            // The vector is shared by every batch entry, hence its stride 0.
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                0,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                to_pointer(args[2]),
                ldc,
                m * n,
                batch_num);
        });
    }
    // vector * matrix
    else if(args[0].get_shape().lens().size() == 1)
    {
        auto a_lens       = args[0].get_shape().lens();
        auto b_lens       = args[1].get_shape().lens();
        std::size_t dim_0 = b_lens.size() - 2;
        std::size_t dim_1 = b_lens.size() - 1;
        bool transb       = args[1].get_shape().transposed();
        bool transa       = false;
        rocblas_int lda   = a_lens[0];
        rocblas_int ldb   = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
        rocblas_int ldc   = b_lens[dim_1];
        rocblas_int m     = 1;
        rocblas_int n     = args[1].get_shape().lens()[dim_1];
        rocblas_int k     = a_lens[0];
        float beta        = 0.0f;
        assert(b_lens[dim_0] == args[0].get_shape().elements());

        std::size_t batch_num = std::accumulate(
            b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

            // The vector is shared by every batch entry, hence its stride 0.
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                0,
                &beta_r,
                to_pointer(args[2]),
                ldc,
                m * n,
                batch_num);
        });
    }
    // (batch) matrix multiplication
    else
    {
        batch_matmul(ctx, output_shape, args);
    }

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
545
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
546
} // namespace migraphx