gemm.cpp 18.9 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <string>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Precision dispatch for rocBLAS `scal`. The shape::as<T> tag selects the
// typed routine and every other argument is forwarded untouched.

/// float elements -> rocblas_sscal.
template <class... Args>
void generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

/// double elements -> rocblas_dscal.
template <class... Args>
void generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

/// Any other element type: no rocBLAS routine is wired up here, so throw.
template <class T, class... Args>
void generic_rocblas_scal(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Precision dispatch for rocBLAS `axpy`. The shape::as<T> tag selects the
// typed routine and every other argument is forwarded untouched.

/// half elements -> rocblas_haxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

/// float elements -> rocblas_saxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

/// double elements -> rocblas_daxpy.
template <class... Args>
void generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

/// Any other element type: no rocBLAS routine is wired up here, so throw.
template <class T, class... Args>
void generic_rocblas_axpy(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Precision dispatch for rocBLAS `dot`. The shape::as<T> tag selects the
// typed routine and every other argument is forwarded untouched.

/// float elements -> rocblas_sdot.
template <class... Args>
void generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

/// double elements -> rocblas_ddot.
template <class... Args>
void generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

/// Any other element type (including half): no routine is wired up here, so throw.
template <class T, class... Args>
void generic_rocblas_dot(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Precision dispatch for rocBLAS `gemv`. The shape::as<T> tag selects the
// typed routine and every other argument is forwarded untouched.

/// float elements -> rocblas_sgemv.
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    rocblas_sgemv(std::forward<Ts>(xs)...);
}

/// double elements -> rocblas_dgemv.
template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    rocblas_dgemv(std::forward<Ts>(xs)...);
}

/// Any other element type: no rocBLAS routine is wired up here, so throw.
/// The message names the wrapper, like the sibling dispatchers do
/// (previously misspelled "GEMMV").
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Precision dispatch for rocBLAS strided-batched GEMM. The shape::as<T> tag
// selects the typed routine and every other argument is forwarded untouched.

/// float elements -> rocblas_sgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

/// double elements -> rocblas_dgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

/// half elements -> rocblas_hgemm_strided_batched.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

/// Any other element type: no rocBLAS routine is wired up here, so throw.
template <class T, class... Args>
void generic_rocblas_batched_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
template <class T>
Paul's avatar
Paul committed
135
136
137
138
139
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

template <class T>
rb_type<T> to_rocblas_type(T x)
{
    return reinterpret_cast<const rb_type<T>&>(x);
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Output shape of the GPU gemm. The last input is the buffer the result is
// written into (compute() returns its last argument). It is dropped so the
// wrapped operator only sees the real operands.
//
// Throws if `inputs` is empty. Previously, `begin() + size() - 1` formed an
// iterator before begin() (undefined behaviour).
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    if(inputs.empty())
        MIGRAPHX_THROW("MIOPEN_GEMM: expected at least one input shape");
    std::vector<shape> orig_inputs(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(orig_inputs);
}
174

Shucai Xiao's avatar
Shucai Xiao committed
175
void miopen_gemm::fill_result(const shape& output_shape,
Shucai Xiao's avatar
Shucai Xiao committed
176
177
                              const argument& result,
                              const argument& c) const
178
{
Shucai Xiao's avatar
Shucai Xiao committed
179
180
    auto out_lens  = output_shape.lens();
    auto c_lens    = c.get_shape().lens();
181
    auto type_size = output_shape.type_size();
Shucai Xiao's avatar
Shucai Xiao committed
182
    if(output_shape == c.get_shape())
183
184
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
185
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
186
187
            hipMemcpy(
                to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
188
189
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
190
    else if(c.single())
191
192
    {
        output_shape.visit_type([&](auto as) {
193
194
            auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
                return to_rocblas_type(as.from(arg.data() + offset_byte));
195
196
197
198
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
199
                hipMemcpy(to_pointer(result, i * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
200
201
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
202
                          hipMemcpyDeviceToDevice);
203
204
205
            }
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
206
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
207
208
209
210
    {
        auto m = out_lens[0];
        auto n = out_lens[1];
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
211
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
212
213
214
215
216
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < m; ++i)
            {
217
                hipMemcpy(to_pointer(result, i * n * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
218
219
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
220
                          hipMemcpyDeviceToDevice);
221
222
223
224
225
226
227
228
229
230
231
232
233
            }
        });
    }
    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
    else
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
234
235
236
                hipMemcpy(to_pointer(result, i * type_size),
                          to_pointer(c, i / out_lens[1] * type_size),
                          type_size,
Shucai Xiao's avatar
Shucai Xiao committed
237
                          hipMemcpyDeviceToDevice);
238
239
240
            }
        });
    }
241
242
}

243
244
245
246
// Row-major (batched) product out = alpha * A * B for two operands of rank >= 2
// (beta is fixed at 0); any dims before the last two are batch dims.
//
// args: {A, B, out}. The result is written into args[2], which is returned.
//
// rocBLAS is column-major, so each call computes out^T = B^T * A^T: B is passed
// first and the m/n extents are swapped. Leading dimensions come from the
// operand strides, which lets transposed operands be used without a copy.
//
// Two strategies:
//  - A and B have equal batch dims, or one of them has none: one
//    strided-batched call. An operand holding a single matrix gets stride 0,
//    so it is reused for every batch entry.
//  - otherwise the batch dims broadcast against each other: one GEMM per
//    output batch entry, with size-1 operand batch dims mapped to index 0.
argument miopen_gemm::batch_matmul(context& ctx,
                                   const shape& output_shape,
                                   const std::vector<argument>& args) const
{
    const shape& a_shape    = args[0].get_shape();
    const shape& b_shape    = args[1].get_shape();
    const bool a_transposed = a_shape.transposed();
    const bool b_transposed = b_shape.transposed();

    const auto a_lens   = a_shape.lens();
    const auto b_lens   = b_shape.lens();
    const auto out_lens = output_shape.lens();
    const auto a_rank   = a_lens.size();
    const auto b_rank   = b_lens.size();
    const auto out_rank = out_lens.size();

    rocblas_int lda = a_shape.strides()[a_transposed ? a_rank - 1 : a_rank - 2];
    rocblas_int ldb = b_shape.strides()[b_transposed ? b_rank - 1 : b_rank - 2];
    rocblas_int ldc = args[2].get_shape().strides()[out_rank - 2];
    rocblas_int m   = out_lens[out_rank - 2];
    rocblas_int n   = out_lens[out_rank - 1];
    rocblas_int k   = a_lens[a_rank - 1];
    float beta      = 0.0f;

    const auto op_a = a_transposed ? rocblas_operation_transpose : rocblas_operation_none;
    const auto op_b = b_transposed ? rocblas_operation_transpose : rocblas_operation_none;

    const std::vector<std::size_t> a_batch(a_lens.begin(), a_lens.end() - 2);
    const std::vector<std::size_t> b_batch(b_lens.begin(), b_lens.end() - 2);

    if(a_batch == b_batch || a_batch.empty() || b_batch.empty())
    {
        auto count = [](const std::vector<std::size_t>& lens) {
            return std::accumulate(
                lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
        };
        const std::size_t a_count     = count(a_batch);
        const std::size_t b_count     = count(b_batch);
        const std::size_t batch_count = std::max(a_count, b_count);
        const rocblas_int stride_a    = (a_count == 1) ? 0 : m * k;
        const rocblas_int stride_b    = (b_count == 1) ? 0 : k * n;
        const rocblas_int stride_c    = m * n;

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            auto base    = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         op_b,
                                         op_a,
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         base(args[1]),
                                         ldb,
                                         stride_b,
                                         base(args[0]),
                                         lda,
                                         stride_a,
                                         &beta_r,
                                         base(args[2]),
                                         ldc,
                                         stride_c,
                                         batch_count);
        });
        return args[2];
    }

    // Broadcasting batch dims: walk every output batch entry.
    const std::vector<std::size_t> out_batch(out_lens.begin(), out_lens.end() - 2);
    const shape::type_t t = output_shape.type();
    shape a_batch_shape{t, a_batch};
    shape b_batch_shape{t, b_batch};
    shape out_batch_shape{t, out_batch};
    const std::size_t a_skip    = out_rank - a_rank;
    const std::size_t b_skip    = out_rank - b_rank;
    const std::size_t type_size = output_shape.type_size();

    // Projects an output batch index onto an operand's batch dims. The operand
    // dims line up with the trailing output dims (the first `skip` output
    // entries are ignored), and a size-1 dim is broadcast, so its index is 0.
    auto project = [](const auto& out_idx, const std::vector<std::size_t>& lens, std::size_t skip) {
        std::vector<std::size_t> idx(lens.size());
        std::transform(out_idx.begin() + skip,
                       out_idx.end(),
                       lens.begin(),
                       idx.begin(),
                       [](std::size_t i, std::size_t len) { return (len == 1) ? std::size_t{0} : i; });
        return idx;
    };

    output_shape.visit_type([&](auto as) {
        auto alpha_r = to_rocblas_type(as(op.alpha));
        auto beta_r  = to_rocblas_type(as(beta));
        auto at      = [&](auto&& arg, std::size_t byte_offset) {
            return to_rocblas_type(as.from(arg.data() + byte_offset));
        };

        shape_for_each(out_batch_shape, [&](const auto& out_idx) {
            const auto a_idx    = project(out_idx, a_batch, a_skip);
            const auto b_idx    = project(out_idx, b_batch, b_skip);
            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
            std::size_t a_ind   = a_batch_shape.index(a_idx.begin(), a_idx.end());
            std::size_t b_ind   = b_batch_shape.index(b_idx.begin(), b_idx.end());

            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 op_b,
                                 op_a,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 at(args[1], k * n * b_ind * type_size),
                                 ldb,
                                 at(args[0], m * k * a_ind * type_size),
                                 lda,
                                 &beta_r,
                                 at(args[2], m * n * out_ind * type_size),
                                 ldc);
        });
    });

    return args[2];
}

wsttiger's avatar
wsttiger committed
365
366
367
// Executes the dot operation `op` on the GPU through rocBLAS.
//
// args: {A, B, out} for two operands, or {A, B, C, out} when a C operand is
// present (out = alpha * A * B + beta * C). Returns the output argument, the
// last element of args.
//
// rocBLAS is column-major while these buffers are row-major, so every call
// computes out^T = B^T * A^T: B is passed as rocBLAS' first operand and the
// m/n extents are swapped. alpha/beta are always passed as host pointers.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    bool is_3inputs = (args.size() == 4);
    if(is_3inputs)
    {
        // Seed the output with (broadcast) C. The GEMM below then scales it by
        // op.beta and accumulates onto it.
        fill_result(output_shape, args[3], args[2]);
        output_shape.visit_type([&](auto as) {
            auto n_dim      = output_shape.lens().size();
            auto dim_1      = n_dim - 1;
            auto dim_0      = n_dim - 2;
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(op.beta));
            bool transa     = args[0].get_shape().transposed();
            bool transb     = args[1].get_shape().transposed();
            rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
            rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
            rocblas_int ldc = args[3].get_shape().strides()[dim_0];
            auto out_lens   = output_shape.lens();
            rocblas_int m   = out_lens[dim_0];
            rocblas_int n   = out_lens[dim_1];
            rocblas_int k   = args[0].get_shape().lens()[dim_1];
            // NOTE(review): the A/B strides below (m*k, k*n) assume A and B
            // carry the same batch dims as the output.
            auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
                                                out_lens.rend(),
                                                std::size_t{1},
                                                std::multiplies<std::size_t>());
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         transb ? rocblas_operation_transpose
                                                : rocblas_operation_none,
                                         transa ? rocblas_operation_transpose
                                                : rocblas_operation_none,
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         to_pointer(args[1]),
                                         ldb,
                                         k * n,
                                         to_pointer(args[0]),
                                         lda,
                                         m * k,
                                         &beta_r,
                                         to_pointer(args[3]),
                                         ldc,
                                         m * n,
                                         num_matrices);
        });

        return args[3];
    }

    // Two-operand cases.
    // Inner product (the output holds a single element).
    if(output_shape.elements() == 1)
    {
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
        // Done as a single 1x1xk GEMM instead of rocBLAS dot + scal.
        //  - The GEMM wrappers cover half precision (generic_rocblas_dot does not).
        //  - alpha is applied by the same call.
        //  - The result is written as a 1x1 matrix, consistent with the
        //    host-pointer alpha/beta convention used by every call here.
        // Both operands are read as k contiguous elements.
        rocblas_int k = args[1].get_shape().elements();
        float beta    = 0.0f;
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 rocblas_operation_none,
                                 rocblas_operation_none,
                                 1,
                                 1,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 1,
                                 to_pointer(args[0]),
                                 // rocBLAS requires ld >= max(1, rows) even when k == 0.
                                 std::max<rocblas_int>(k, 1),
                                 &beta_r,
                                 to_pointer(args[2]),
                                 1);
        });
    }
    // (batched) matrix * vector
    else if(args[1].get_shape().lens().size() == 1)
    {
        auto a_lens       = args[0].get_shape().lens();
        auto b_lens       = args[1].get_shape().lens();
        std::size_t dim_0 = a_lens.size() - 2;
        std::size_t dim_1 = a_lens.size() - 1;
        bool transa       = args[0].get_shape().transposed();
        bool transb       = false;
        rocblas_int lda   = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
        rocblas_int ldb   = 1;
        rocblas_int ldc   = 1;
        rocblas_int m     = a_lens[dim_0];
        rocblas_int n     = 1;
        rocblas_int k     = a_lens[dim_1];
        float beta        = 0.0f;
        assert(a_lens.back() == args[1].get_shape().elements());

        std::size_t batch_num = std::accumulate(
            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            // The vector is shared by every batch entry (stride 0).
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                0,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                to_pointer(args[2]),
                ldc,
                m * n,
                batch_num);
        });
    }
    // vector * (batched) matrix
    else if(args[0].get_shape().lens().size() == 1)
    {
        auto a_lens       = args[0].get_shape().lens();
        auto b_lens       = args[1].get_shape().lens();
        std::size_t dim_0 = b_lens.size() - 2;
        std::size_t dim_1 = b_lens.size() - 1;
        bool transb       = args[1].get_shape().transposed();
        bool transa       = false;
        rocblas_int lda   = a_lens[0];
        rocblas_int ldb   = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
        rocblas_int ldc   = b_lens[dim_1];
        rocblas_int m     = 1;
        rocblas_int n     = args[1].get_shape().lens()[dim_1];
        rocblas_int k     = a_lens[0];
        float beta        = 0.0f;
        assert(b_lens[dim_0] == args[0].get_shape().elements());

        std::size_t batch_num = std::accumulate(
            b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            // The vector is shared by every batch entry (stride 0).
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                0,
                &beta_r,
                to_pointer(args[2]),
                ldc,
                m * n,
                batch_num);
        });
    }
    // (batched) matrix * matrix
    else
    {
        batch_matmul(ctx, output_shape, args);
    }

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
541
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
542
} // namespace migraphx