"test/gpu/ops_test.cpp" did not exist on "77cc8d160c30c3c98e6cf23a42a175475eb02753"
gemm.cpp 18.5 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Type-dispatched wrapper around rocBLAS SCAL (x := alpha * x).
// The shape::as<> tag selects the precision-specific routine; every other
// argument is perfectly forwarded. Element types with no overload below
// fall through to the catch-all, which throws.
template <typename... Args>
void generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

template <typename... Args>
void generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

template <typename T, typename... Args>
void generic_rocblas_scal(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Type-dispatched wrapper around rocBLAS AXPY (y := alpha * x + y).
// half, float and double map to the h/s/d routines; the remaining arguments
// are perfectly forwarded. Any other element type hits the throwing fallback.
template <typename... Args>
void generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

template <typename... Args>
void generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

template <typename... Args>
void generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

template <typename T, typename... Args>
void generic_rocblas_axpy(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Type-dispatched wrapper around rocBLAS DOT (inner product of two vectors).
// Only float and double are routed to rocBLAS here; other element types
// reach the catch-all overload and throw.
template <typename... Args>
void generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

template <typename... Args>
void generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

template <typename T, typename... Args>
void generic_rocblas_dot(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Type-dispatched wrapper around rocBLAS GEMV (matrix-vector product).
// The shape::as<> tag picks the float/double routine and all remaining
// arguments are perfectly forwarded. Element types without an overload
// reach the catch-all, which throws with a message naming this wrapper.
template <class... Ts>
void generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    rocblas_sgemv(std::forward<Ts>(xs)...);
}

template <class... Ts>
void generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    rocblas_dgemv(std::forward<Ts>(xs)...);
}

template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    // The message names the GEMV wrapper, matching the other
    // GENERIC_ROCBLAS_* fallbacks in this file.
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Type-dispatched wrapper around rocBLAS strided-batched GEMM.
// float, double and half select the s/d/h *_gemm_strided_batched routine;
// the remaining arguments are perfectly forwarded. Anything else throws.
template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

template <typename... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

template <typename T, typename... Args>
void generic_rocblas_batched_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
// Type trait mapping a MIGraphX element type to the element type rocBLAS
// expects. By default the type is passed through unchanged; specializations
// override the mapping for types rocBLAS spells differently.
template <typename T>
struct compute_rocblas_type
{
    using type = T; // identity mapping
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Value overload: reinterprets x's object representation as rb_type<T> and
// returns a copy. For types whose rb_type is T itself this is a plain copy.
template <typename T>
rb_type<T> to_rocblas_type(T x)
{
    const auto& as_rocblas = reinterpret_cast<const rb_type<T>&>(x);
    return as_rocblas;
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Non-template overload for half: returns the same bits viewed as rocblas_half.
rocblas_half to_rocblas_type(half x)
{
    const auto& as_rocblas = reinterpret_cast<const rocblas_half&>(x);
    return as_rocblas;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Shape of the GEMM result. The trailing input is the output buffer (compute()
// writes into the last argument), so it is dropped and the remaining operand
// shapes are delegated to the wrapped operator.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    std::vector<shape> operand_shapes(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(operand_shapes);
}
174

Shucai Xiao's avatar
Shucai Xiao committed
175
void miopen_gemm::fill_result(const shape& output_shape,
Shucai Xiao's avatar
Shucai Xiao committed
176
177
                              const argument& result,
                              const argument& c) const
178
{
Shucai Xiao's avatar
Shucai Xiao committed
179
180
    auto out_lens  = output_shape.lens();
    auto c_lens    = c.get_shape().lens();
181
    auto type_size = output_shape.type_size();
Shucai Xiao's avatar
Shucai Xiao committed
182
    if(output_shape == c.get_shape())
183
184
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
185
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
186
187
            hipMemcpy(
                to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
188
189
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
190
    else if(c.single())
191
192
    {
        output_shape.visit_type([&](auto as) {
193
194
            auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
                return to_rocblas_type(as.from(arg.data() + offset_byte));
195
196
197
198
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
199
                hipMemcpy(to_pointer(result, i * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
200
201
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
202
                          hipMemcpyDeviceToDevice);
203
204
205
            }
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
206
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
207
208
209
210
    {
        auto m = out_lens[0];
        auto n = out_lens[1];
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
211
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
212
213
214
215
216
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < m; ++i)
            {
217
                hipMemcpy(to_pointer(result, i * n * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
218
219
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
220
                          hipMemcpyDeviceToDevice);
221
222
223
224
225
226
227
228
229
230
231
232
233
            }
        });
    }
    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
    else
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
234
235
236
                hipMemcpy(to_pointer(result, i * type_size),
                          to_pointer(c, i / out_lens[1] * type_size),
                          type_size,
Shucai Xiao's avatar
Shucai Xiao committed
237
                          hipMemcpyDeviceToDevice);
238
239
240
            }
        });
    }
241
242
}

243
244
245
246
// Two-input (batched) matrix product.
//
// args[0] is A with lens (..., m, k), args[1] is B with lens (..., k, n) and
// args[2] is the output buffer with lens (..., m, n). The leading (batch) axes
// of A and B broadcast against the output's batch axes. Returns args[2].
//
// rocBLAS receives the operands as (B, A) with sizes (n, m, k), the usual way
// of driving a column-major BLAS with row-major data (C^T = B^T * A^T).
argument miopen_gemm::batch_matmul(context& ctx,
                                   const shape& output_shape,
                                   const std::vector<argument>& args) const
{
    const auto& a_shape = args[0].get_shape();
    const auto& b_shape = args[1].get_shape();
    const bool transa   = a_shape.transposed();
    const bool transb   = b_shape.transposed();
    const auto op_a     = transa ? rocblas_operation_transpose : rocblas_operation_none;
    const auto op_b     = transb ? rocblas_operation_transpose : rocblas_operation_none;

    const auto a_lens   = a_shape.lens();
    const auto b_lens   = b_shape.lens();
    const auto out_lens = output_shape.lens();
    const auto a_rank   = a_lens.size();
    const auto b_rank   = b_lens.size();
    const auto out_rank = out_lens.size();

    // Leading dimensions come from the stored row stride of each operand.
    rocblas_int lda = a_shape.strides()[transa ? a_rank - 1 : a_rank - 2];
    rocblas_int ldb = b_shape.strides()[transb ? b_rank - 1 : b_rank - 2];
    rocblas_int ldc = args[2].get_shape().strides()[out_rank - 2];
    rocblas_int m   = out_lens[out_rank - 2];
    rocblas_int n   = out_lens[out_rank - 1];
    rocblas_int k   = a_lens[a_rank - 1];
    // There is no C operand in the two-input product.
    const float beta = 0.0f;

    const std::vector<std::size_t> a_batch(a_lens.begin(), a_lens.end() - 2);
    const std::vector<std::size_t> b_batch(b_lens.begin(), b_lens.end() - 2);
    auto count_of = [](const std::vector<std::size_t>& dims) {
        return std::accumulate(
            dims.begin(), dims.end(), std::size_t{1}, std::multiplies<std::size_t>());
    };

    // Same batch axes, or one operand has none: a single strided-batched call.
    // An operand holding only one matrix gets stride 0 so it is reused per batch.
    if(a_batch == b_batch || a_batch.empty() || b_batch.empty())
    {
        const std::size_t a_count     = count_of(a_batch);
        const std::size_t b_count     = count_of(b_batch);
        const std::size_t batch_count = std::max(a_count, b_count);
        rocblas_int stride_a          = (a_count == 1) ? 0 : m * k;
        rocblas_int stride_b          = (b_count == 1) ? 0 : k * n;
        rocblas_int stride_c          = m * n;

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            auto ptr = [&](const argument& arg) { return to_rocblas_type(as.from(arg.data())); };

            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         op_b,
                                         op_a,
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         ptr(args[1]),
                                         ldb,
                                         stride_b,
                                         ptr(args[0]),
                                         lda,
                                         stride_a,
                                         &beta_r,
                                         ptr(args[2]),
                                         ldc,
                                         stride_c,
                                         batch_count);
        });
        return args[2];
    }

    // Differing batch axes: walk every output batch index and issue one GEMM
    // per output matrix, locating the broadcast source matrices of A and B.
    const std::vector<std::size_t> out_batch(out_lens.begin(), out_lens.end() - 2);
    const shape::type_t t = output_shape.type();
    const shape a_batch_shape{t, a_batch};
    const shape b_batch_shape{t, b_batch};
    const shape out_batch_shape{t, out_batch};
    const std::size_t a_skip = out_rank - a_rank;
    const std::size_t b_skip = out_rank - b_rank;
    const auto type_size     = output_shape.type_size();

    // Right-align an output batch index with an operand's batch axes; size-1
    // operand axes are broadcast, so their index is pinned to 0.
    auto project = [](const auto& out_idx, std::size_t skip, const std::vector<std::size_t>& dims) {
        std::vector<std::size_t> idx(dims.size());
        std::transform(out_idx.begin() + skip,
                       out_idx.end(),
                       dims.begin(),
                       idx.begin(),
                       [](auto i, auto len) { return (len == 1) ? 0 : i; });
        return idx;
    };

    shape_for_each(out_batch_shape, [&](const auto& out_idx) {
        const auto a_idx          = project(out_idx, a_skip, a_batch);
        const auto b_idx          = project(out_idx, b_skip, b_batch);
        const std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
        const std::size_t a_ind   = a_batch_shape.index(a_idx.begin(), a_idx.end());
        const std::size_t b_ind   = b_batch_shape.index(b_idx.begin(), b_idx.end());

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            auto ptr_at  = [&](const argument& arg, std::size_t byte_offset) {
                return to_rocblas_type(as.from(arg.data() + byte_offset));
            };

            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 op_b,
                                 op_a,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 ptr_at(args[1], k * n * b_ind * type_size),
                                 ldb,
                                 ptr_at(args[0], m * k * a_ind * type_size),
                                 lda,
                                 &beta_r,
                                 ptr_at(args[2], m * n * out_ind * type_size),
                                 ldc);
        });
    });

    return args[2];
}

wsttiger's avatar
wsttiger committed
365
366
367
// GPU GEMM entry point.
//
// 4 arguments {A, B, C, out}: out = alpha * A * B + beta * C for 2-D operands;
//                             returns args[3].
// 3 arguments {A, B, out}:    out = alpha * A * B; returns args[2]. Dispatch:
//   - single-element output -> dot product of A and B, then scaled by alpha
//   - B is 1-D              -> (batched) matrix * vector
//   - A is 1-D              -> vector * (batched) matrix
//   - otherwise             -> batch_matmul
// rocBLAS is always handed (B, A) with sizes (n, m, k) so that row-major data
// can be fed to the column-major API.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    auto op_for = [](bool transposed) {
        return transposed ? rocblas_operation_transpose : rocblas_operation_none;
    };

    if(args.size() == 4)
    {
        // Seed the output with C; the GEMM then accumulates into it with beta = op.beta.
        fill_result(output_shape, args[3], args[2]);

        output_shape.visit_type([&](auto as) {
            const auto& a_shape = args[0].get_shape();
            const auto& b_shape = args[1].get_shape();
            const bool transa   = a_shape.transposed();
            const bool transb   = b_shape.transposed();
            const auto out_lens = output_shape.lens();

            rocblas_int lda = a_shape.strides()[transa ? 1 : 0];
            rocblas_int ldb = b_shape.strides()[transb ? 1 : 0];
            rocblas_int ldc = args[3].get_shape().strides()[0];
            rocblas_int m   = out_lens[0];
            rocblas_int n   = out_lens[1];
            rocblas_int k   = a_shape.lens()[1];

            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(op.beta));
            auto ptr = [&](const argument& arg) { return to_rocblas_type(as.from(arg.data())); };

            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 op_for(transb),
                                 op_for(transa),
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 ptr(args[1]),
                                 ldb,
                                 ptr(args[0]),
                                 lda,
                                 &beta_r,
                                 ptr(args[3]),
                                 ldc);
        });
        return args[3];
    }

    // Two operands plus the output buffer args[2]; there is no C term.
    const float beta  = 0.0f;
    const auto a_lens = args[0].get_shape().lens();
    const auto b_lens = args[1].get_shape().lens();

    // Single-element output: inner product, then scaled in place by alpha.
    if(output_shape.elements() == 1)
    {
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto ptr = [&](const argument& arg) { return to_rocblas_type(as.from(arg.data())); };

            generic_rocblas_dot(as,
                                ctx.get_stream().get_rocblas(),
                                args[1].get_shape().elements(),
                                ptr(args[0]),
                                1,
                                ptr(args[1]),
                                1,
                                ptr(args[2]));
            generic_rocblas_scal(
                as, ctx.get_stream().get_rocblas(), 1, &alpha_r, ptr(args[2]), 1);
        });
        return args[2];
    }

    // matrix * vector: A is (..., m, k) and B is a length-k vector, shared by every batch.
    if(b_lens.size() == 1)
    {
        const auto& a_shape        = args[0].get_shape();
        const std::size_t row_axis = a_lens.size() - 2;
        const std::size_t col_axis = a_lens.size() - 1;
        const bool transa          = a_shape.transposed();

        rocblas_int lda = a_shape.strides()[transa ? col_axis : row_axis];
        rocblas_int ldb = 1;
        rocblas_int ldc = 1;
        rocblas_int m   = a_lens[row_axis];
        rocblas_int n   = 1;
        rocblas_int k   = a_lens[col_axis];
        assert(a_lens.back() == args[1].get_shape().elements());

        const std::size_t batch_num = std::accumulate(
            a_lens.begin(), a_lens.end() - 2, std::size_t{1}, std::multiplies<std::size_t>());

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            auto ptr = [&](const argument& arg) { return to_rocblas_type(as.from(arg.data())); };

            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         rocblas_operation_none,
                                         op_for(transa),
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         ptr(args[1]),
                                         ldb,
                                         0,
                                         ptr(args[0]),
                                         lda,
                                         m * k,
                                         &beta_r,
                                         ptr(args[2]),
                                         ldc,
                                         m * n,
                                         batch_num);
        });
        return args[2];
    }

    // vector * matrix: A is a length-k vector shared by every batch, B is (..., k, n).
    if(a_lens.size() == 1)
    {
        const auto& b_shape        = args[1].get_shape();
        const std::size_t row_axis = b_lens.size() - 2;
        const std::size_t col_axis = b_lens.size() - 1;
        const bool transb          = b_shape.transposed();

        rocblas_int lda = a_lens[0];
        rocblas_int ldb = b_shape.strides()[transb ? col_axis : row_axis];
        rocblas_int ldc = b_lens[col_axis];
        rocblas_int m   = 1;
        rocblas_int n   = b_lens[col_axis];
        rocblas_int k   = a_lens[0];
        assert(b_lens[row_axis] == args[0].get_shape().elements());

        const std::size_t batch_num = std::accumulate(
            b_lens.begin(), b_lens.end() - 2, std::size_t{1}, std::multiplies<std::size_t>());

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(beta));
            auto ptr = [&](const argument& arg) { return to_rocblas_type(as.from(arg.data())); };

            generic_rocblas_batched_gemm(as,
                                         ctx.get_stream().get_rocblas(),
                                         op_for(transb),
                                         rocblas_operation_none,
                                         n,
                                         m,
                                         k,
                                         &alpha_r,
                                         ptr(args[1]),
                                         ldb,
                                         k * n,
                                         ptr(args[0]),
                                         lda,
                                         0,
                                         &beta_r,
                                         ptr(args[2]),
                                         ldc,
                                         m * n,
                                         batch_num);
        });
        return args[2];
    }

    // General (batched, possibly broadcast) matrix * matrix.
    batch_matmul(ctx, output_shape, args);
    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
535
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
536
} // namespace migraphx