gemm.cpp 20.4 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// SCAL dispatch for float elements: the shape::as<float> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_sscal.
template <class... Params>
void generic_rocblas_scal(shape::as<float>, Params&&... params)
{
    rocblas_sscal(std::forward<Params>(params)...);
}

// SCAL dispatch for double elements: the shape::as<double> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_dscal.
template <class... Params>
void generic_rocblas_scal(shape::as<double>, Params&&... params)
{
    rocblas_dscal(std::forward<Params>(params)...);
}

// Catch-all SCAL overload, chosen when no type-specific overload matches the
// element type. The arguments are ignored and an error is raised instead.
template <class Unsupported, class... Ignored>
void generic_rocblas_scal(shape::as<Unsupported>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// AXPY dispatch for half elements: the shape::as<half> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_haxpy.
template <class... Params>
void generic_rocblas_axpy(shape::as<half>, Params&&... params)
{
    rocblas_haxpy(std::forward<Params>(params)...);
}

// AXPY dispatch for float elements: the shape::as<float> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_saxpy.
template <class... Params>
void generic_rocblas_axpy(shape::as<float>, Params&&... params)
{
    rocblas_saxpy(std::forward<Params>(params)...);
}

// AXPY dispatch for double elements: the shape::as<double> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_daxpy.
template <class... Params>
void generic_rocblas_axpy(shape::as<double>, Params&&... params)
{
    rocblas_daxpy(std::forward<Params>(params)...);
}

// Catch-all AXPY overload, chosen when no type-specific overload matches the
// element type. The arguments are ignored and an error is raised instead.
template <class Unsupported, class... Ignored>
void generic_rocblas_axpy(shape::as<Unsupported>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// DOT dispatch for float elements: the shape::as<float> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_sdot.
template <class... Params>
void generic_rocblas_dot(shape::as<float>, Params&&... params)
{
    rocblas_sdot(std::forward<Params>(params)...);
}

// DOT dispatch for double elements: the shape::as<double> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_ddot.
template <class... Params>
void generic_rocblas_dot(shape::as<double>, Params&&... params)
{
    rocblas_ddot(std::forward<Params>(params)...);
}

// Catch-all DOT overload, chosen when no type-specific overload matches the
// element type. The arguments are ignored and an error is raised instead.
template <class Unsupported, class... Ignored>
void generic_rocblas_dot(shape::as<Unsupported>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// GEMV dispatch for float elements: the shape::as<float> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_sgemv.
template <class... Params>
void generic_rocblas_gemv(shape::as<float>, Params&&... params)
{
    rocblas_sgemv(std::forward<Params>(params)...);
}

// GEMV dispatch for double elements: the shape::as<double> tag only selects this
// overload; all remaining arguments are perfectly forwarded to rocblas_dgemv.
template <class... Params>
void generic_rocblas_gemv(shape::as<double>, Params&&... params)
{
    rocblas_dgemv(std::forward<Params>(params)...);
}

// Catch-all GEMV overload, chosen when no type-specific overload matches the
// element type. The arguments are ignored and an error is raised instead.
// The message prefix mirrors the function name, like the sibling fallbacks
// (it previously read "GEMMV", which matched no function in this file).
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Strided-batched GEMM dispatch for float elements: the shape::as<float> tag only
// selects this overload; all remaining arguments are perfectly forwarded to
// rocblas_sgemm_strided_batched.
template <class... Params>
void generic_rocblas_batched_gemm(shape::as<float>, Params&&... params)
{
    rocblas_sgemm_strided_batched(std::forward<Params>(params)...);
}

// Strided-batched GEMM dispatch for double elements: the shape::as<double> tag only
// selects this overload; all remaining arguments are perfectly forwarded to
// rocblas_dgemm_strided_batched.
template <class... Params>
void generic_rocblas_batched_gemm(shape::as<double>, Params&&... params)
{
    rocblas_dgemm_strided_batched(std::forward<Params>(params)...);
}

// Strided-batched GEMM dispatch for half elements: the shape::as<half> tag only
// selects this overload; all remaining arguments are perfectly forwarded to
// rocblas_hgemm_strided_batched.
template <class... Params>
void generic_rocblas_batched_gemm(shape::as<half>, Params&&... params)
{
    rocblas_hgemm_strided_batched(std::forward<Params>(params)...);
}

// Catch-all strided-batched GEMM overload, chosen when no type-specific overload
// matches the element type. The arguments are ignored and an error is raised.
template <class Unsupported, class... Ignored>
void generic_rocblas_batched_gemm(shape::as<Unsupported>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
// Maps an element type to the type handed to rocBLAS.
// Primary template: the type is passed through unchanged; specializations
// override this for types that need a different rocBLAS representation.
template <class T>
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Converts a value to its rocBLAS representation by reinterpreting the object
// as rb_type<T> and returning a copy of it. For types that compute_rocblas_type
// leaves unchanged this just returns the value.
template <class Value>
rb_type<Value> to_rocblas_type(Value x)
{
    const auto& reinterpreted = reinterpret_cast<const rb_type<Value>&>(x);
    return reinterpreted;
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Non-template overload for half values: reinterprets the half object as a
// rocblas_half and returns it by value.
rocblas_half to_rocblas_type(half x)
{
    const auto& as_rocblas_half = reinterpret_cast<const rocblas_half&>(x);
    return as_rocblas_half;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Computes the output shape by delegating to the wrapped operator.
//
// The last entry of `inputs` is dropped before delegating; presumably it is the
// shape of the pre-allocated output buffer appended for the GPU operator
// (NOTE(review): confirm against the lowering pass). An empty `inputs` is
// rejected explicitly, because stepping back from begin() would be undefined
// behaviour.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    if(inputs.empty())
    {
        MIGRAPHX_THROW("GPU GEMM: compute_shape requires at least one input shape");
    }
    std::vector<shape> orig_inputs(inputs.begin(), inputs.begin() + inputs.size() - 1);
    return op.compute_shape(orig_inputs);
}
174

Shucai Xiao's avatar
Shucai Xiao committed
175
void miopen_gemm::fill_result(const shape& output_shape,
Shucai Xiao's avatar
Shucai Xiao committed
176
177
                              const argument& result,
                              const argument& c) const
178
{
Shucai Xiao's avatar
Shucai Xiao committed
179
180
    auto out_lens  = output_shape.lens();
    auto c_lens    = c.get_shape().lens();
181
    auto type_size = output_shape.type_size();
Shucai Xiao's avatar
Shucai Xiao committed
182
    if(output_shape == c.get_shape())
183
184
    {
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
185
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
186
187
            hipMemcpy(
                to_pointer(result), to_pointer(c), output_shape.bytes(), hipMemcpyDeviceToDevice);
188
189
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
190
    else if(c.single())
191
192
    {
        output_shape.visit_type([&](auto as) {
193
194
            auto to_pointer = [&](auto&& arg, std::size_t offset_byte = 0) {
                return to_rocblas_type(as.from(arg.data() + offset_byte));
195
196
197
198
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
199
                hipMemcpy(to_pointer(result, i * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
200
201
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
202
                          hipMemcpyDeviceToDevice);
203
204
205
            }
        });
    }
Shucai Xiao's avatar
Shucai Xiao committed
206
    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
207
208
209
210
    {
        auto m = out_lens[0];
        auto n = out_lens[1];
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
211
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
212
213
214
215
216
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < m; ++i)
            {
217
                hipMemcpy(to_pointer(result, i * n * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
218
219
                          to_pointer(c),
                          c.get_shape().bytes(),
Shucai Xiao's avatar
Shucai Xiao committed
220
                          hipMemcpyDeviceToDevice);
221
222
223
224
225
226
227
228
229
230
231
232
233
            }
        });
    }
    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
    else
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg, std::size_t offset) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };

            for(std::size_t i = 0; i < output_shape.elements(); ++i)
            {
234
235
236
                hipMemcpy(to_pointer(result, i * type_size),
                          to_pointer(c, i / out_lens[1] * type_size),
                          type_size,
Shucai Xiao's avatar
Shucai Xiao committed
237
                          hipMemcpyDeviceToDevice);
238
239
240
            }
        });
    }
241
242
}

wsttiger's avatar
wsttiger committed
243
244
245
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
wsttiger's avatar
wsttiger committed
246
{
247
    bool is_3inputs = (args.size() == 4);
Shucai Xiao's avatar
Shucai Xiao committed
248
    if(is_3inputs)
249
    {
Shucai Xiao's avatar
Shucai Xiao committed
250
        fill_result(output_shape, args[3], args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
251

252
253
254
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(op.beta));
Shucai Xiao's avatar
Shucai Xiao committed
255
256
257
258
            bool transa     = args[0].get_shape().transposed();
            bool transb     = args[1].get_shape().transposed();
            rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
            rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
259
            rocblas_int ldc = args[3].get_shape().strides()[0];
Shucai Xiao's avatar
Shucai Xiao committed
260
261
262
263
            auto out_lens   = output_shape.lens();
            rocblas_int m   = out_lens[0];
            rocblas_int n   = out_lens[1];
            rocblas_int k   = args[0].get_shape().lens()[1];
264
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
265
266
267
            auto cpu_a      = migraphx::gpu::from_gpu(args[0]);
            auto cpu_b      = migraphx::gpu::from_gpu(args[1]);
            auto cpu_res    = migraphx::gpu::from_gpu(args[3]);
268
269
270
271
272
273
274
275
276
277
278
279
            std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
            std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
            std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
            std::cout << "gpu::gemm, transb = " << transb << std::endl;
            std::cout << "gpu::gemm, transa = " << transb << std::endl;
            std::cout << "gpu::gemm, m = " << m << std::endl;
            std::cout << "gpu::gemm, n = " << n << std::endl;
            std::cout << "gpu::gemm, k = " << k << std::endl;
            std::cout << "gpu::gemm, lda = " << lda << std::endl;
            std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
            std::cout << "gpu::gemm, ldc = " << ldc << std::endl;

280
281
282
283
284
285
286
287
288
289
290
291
292
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 lda,
                                 &beta_r,
293
                                 to_pointer(args[3]),
294
295
296
297
298
299
                                 ldc);

        });

        return args[3];
    }
300

301
302
    // 2 input arguments cases
    // vector inner product
303
    std::size_t type_size = output_shape.type_size();
Shucai Xiao's avatar
Shucai Xiao committed
304
    if(output_shape.elements() == 1)
305
    {
306
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
307
308
309
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
Shucai Xiao's avatar
Shucai Xiao committed
310
311
312
            generic_rocblas_dot(as,
                                ctx.get_stream().get_rocblas(),
                                args[1].get_shape().elements(),
313
314
315
316
                                to_pointer(args[0]),
                                1,
                                to_pointer(args[1]),
                                1,
317
                                to_pointer(args[2]));
318

Shucai Xiao's avatar
Shucai Xiao committed
319
            generic_rocblas_scal(
Shucai Xiao's avatar
Shucai Xiao committed
320
                as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
321
322
        });
    }
323
    // matrix * vector
Shucai Xiao's avatar
Shucai Xiao committed
324
    else if(args[1].get_shape().lens().size() == 1)
325
    {
Shucai Xiao's avatar
Shucai Xiao committed
326
        auto a_lens       = args[0].get_shape().lens();
327
328
329
        std::size_t dim_0 = a_lens.size() - 2;
        std::size_t dim_1 = a_lens.size() - 1;
        bool trans        = args[0].get_shape().transposed();
Shucai Xiao's avatar
Shucai Xiao committed
330
331
332
333
        rocblas_int m     = a_lens[trans ? dim_1 : dim_0];
        rocblas_int n     = a_lens[trans ? dim_0 : dim_1];
        float beta        = 0.0f;
        rocblas_int lda   = args[0].get_shape().strides()[trans ? dim_1 : dim_0];
334
335

        assert(a_lens.back() == args[1].get_shape().elements());
Shucai Xiao's avatar
Shucai Xiao committed
336
337
        std::size_t batch_num = std::accumulate(
            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
338
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
339
340
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
Shucai Xiao's avatar
Shucai Xiao committed
341
342
343
344
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                return to_rocblas_type(as.from(arg.data() + offset));
            };
            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
345
            {
346
347
348
349
350
351
                generic_rocblas_gemv(as,
                                     ctx.get_stream().get_rocblas(),
                                     trans ? rocblas_operation_transpose : rocblas_operation_none,
                                     m,
                                     n,
                                     &alpha_r,
352
                                     to_pointer(args[0], batch_no * m * n * type_size),
353
354
355
356
                                     lda,
                                     to_pointer(args[1]),
                                     1,
                                     &beta_r,
357
                                     to_pointer(args[2], batch_no * n * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
358
                                     1);
359
360
361
            }
        });
    }
362
    // vector * matrix
Shucai Xiao's avatar
Shucai Xiao committed
363
    else if(args[0].get_shape().lens().size() == 1)
364
    {
365
        auto a_lens       = args[0].get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
366
        auto b_lens       = args[1].get_shape().lens();
367
368
        std::size_t dim_0 = b_lens.size() - 2;
        std::size_t dim_1 = b_lens.size() - 1;
369
370
371
372
373
        bool transb       = args[1].get_shape().transposed();
        bool transa       = false;
        rocblas_int lda   = a_lens[0];
        rocblas_int ldb   = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
        rocblas_int ldc   = b_lens[dim_1];
Shucai Xiao's avatar
Shucai Xiao committed
374
375
376
        rocblas_int m     = 1;
        rocblas_int n     = args[1].get_shape().lens()[dim_1];
        rocblas_int k     = a_lens[0];
Shucai Xiao's avatar
Shucai Xiao committed
377
        float beta        = 0.0f;
378
        assert(b_lens[dim_0] == args[0].get_shape().elements());
379

Shucai Xiao's avatar
Shucai Xiao committed
380
381
        std::size_t batch_num = std::accumulate(
            b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
382

Shucai Xiao's avatar
Shucai Xiao committed
383
384
        auto cpu_a   = migraphx::gpu::from_gpu(args[0]);
        auto cpu_b   = migraphx::gpu::from_gpu(args[1]);
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
        auto cpu_res = migraphx::gpu::from_gpu(args[2]);
        std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
        std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
        std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
        std::cout << "gpu::gemm, transb = " << transb << std::endl;
        std::cout << "gpu::gemm, transa = " << transb << std::endl;
        std::cout << "gpu::gemm, m = " << m << std::endl;
        std::cout << "gpu::gemm, n = " << n << std::endl;
        std::cout << "gpu::gemm, k = " << k << std::endl;
        std::cout << "gpu::gemm, lda = " << lda << std::endl;
        std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
        std::cout << "gpu::gemm, ldc = " << ldc << std::endl;

        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
        });

403
        output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
404
405
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
Shucai Xiao's avatar
Shucai Xiao committed
406
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
407
408

            generic_rocblas_batched_gemm(
Shucai Xiao's avatar
Shucai Xiao committed
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                0,
                &beta_r,
                to_pointer(args[2]),
                ldc,
                m * n,
                batch_num);
428
        });
429
430

        return args[2];
431
432
433
434
    }
    // (batch) matrix multiplication
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
435
436
437
438
        bool transa   = args[0].get_shape().transposed();
        bool transb   = args[1].get_shape().transposed();
        auto a_lens   = args[0].get_shape().lens();
        auto b_lens   = args[1].get_shape().lens();
439
440
        auto out_lens = output_shape.lens();

Shucai Xiao's avatar
Shucai Xiao committed
441
442
443
444
445
446
447
448
        rocblas_int lda =
            args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2];
        rocblas_int ldb =
            args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2];
        rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2];
        rocblas_int m   = out_lens[out_lens.size() - 2];
        rocblas_int n   = out_lens[out_lens.size() - 1];
        rocblas_int k   = args[0].get_shape().lens()[a_lens.size() - 1];
Shucai Xiao's avatar
Shucai Xiao committed
449
        float beta      = 0.0f;
450
451
        auto input_dims = std::min(a_lens.size(), b_lens.size());
        std::size_t axis{0};
Shucai Xiao's avatar
Shucai Xiao committed
452
        for(axis = 2; axis < input_dims; ++axis)
453
        {
Shucai Xiao's avatar
Shucai Xiao committed
454
            if(a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
455
456
457
458
459
460
            {
                break;
            }
        }

        // The number of matrices that can be computed in one call
Shucai Xiao's avatar
Shucai Xiao committed
461
        // batch_num > 1, we need to call the batch_gemm function,
462
        // otherwise, call the gemm function directly
Shucai Xiao's avatar
Shucai Xiao committed
463
464
465
466
467
        std::size_t num_matrices =
            std::accumulate(a_lens.rbegin() + 2,
                            (axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis),
                            std::size_t{1},
                            std::multiplies<std::size_t>());
468
469
        std::size_t a_len_diff = out_lens.size() - a_lens.size();
        std::size_t b_len_diff = out_lens.size() - b_lens.size();
Shucai Xiao's avatar
Shucai Xiao committed
470
471
472
473
474
475
        std::vector<std::size_t> a_batch_lens(a_lens.begin(),
                                              a_lens.begin() + a_lens.size() - axis);
        std::vector<std::size_t> b_batch_lens(b_lens.begin(),
                                              b_lens.begin() + b_lens.size() - axis);
        std::vector<std::size_t> out_batch_lens(out_lens.begin(),
                                                out_lens.begin() + out_lens.size() - axis);
476
477
478
479

        shape::type_t t = output_shape.type();
        shape a_batch_shape{t, a_batch_lens};
        shape b_batch_shape{t, b_batch_lens};
Shucai Xiao's avatar
Shucai Xiao committed
480
        shape out_batch_shape{t, out_batch_lens};
481

Shucai Xiao's avatar
Shucai Xiao committed
482
        shape_for_each(out_batch_shape, [&](auto out_idx) {
483
484
485
            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
            std::vector<std::size_t> a_idx(a_lens.size() - axis);
            std::vector<std::size_t> b_idx(b_lens.size() - axis);
Shucai Xiao's avatar
Shucai Xiao committed
486
487
488
489
490
491
492
493
494
495
            std::transform(out_idx.begin() + a_len_diff,
                           out_idx.end(),
                           a_batch_lens.begin(),
                           a_idx.begin(),
                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
            std::transform(out_idx.begin() + b_len_diff,
                           out_idx.end(),
                           b_batch_lens.begin(),
                           b_idx.begin(),
                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
496
497
498
499
500

            std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end());
            std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());

            output_shape.visit_type([&](auto as) {
Shucai Xiao's avatar
Shucai Xiao committed
501
502
                auto alpha_r    = to_rocblas_type(as(op.alpha));
                auto beta_r     = to_rocblas_type(as(beta));
Shucai Xiao's avatar
Shucai Xiao committed
503
504
505
                auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                    return to_rocblas_type(as.from(arg.data() + offset));
                };
Shucai Xiao's avatar
Shucai Xiao committed
506
                if(num_matrices > 1)
507
508
509
510
511
512
513
514
515
516
                {
                    generic_rocblas_batched_gemm(
                        as,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
                        transa ? rocblas_operation_transpose : rocblas_operation_none,
                        n,
                        m,
                        k,
                        &alpha_r,
517
                        to_pointer(args[1], k * n * num_matrices * b_ind * type_size),
518
519
                        ldb,
                        k * n,
520
                        to_pointer(args[0], m * k * num_matrices * a_ind * type_size),
521
522
523
                        lda,
                        m * k,
                        &beta_r,
524
                        to_pointer(args[2], m * n * num_matrices * out_ind * type_size),
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
                        ldc,
                        m * n,
                        num_matrices);
                }
                // num_matrices per call is 1
                else
                {
                    generic_rocblas_gemm(
                        as,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
                        transa ? rocblas_operation_transpose : rocblas_operation_none,
                        n,
                        m,
                        k,
                        &alpha_r,
541
                        to_pointer(args[1], k * n * b_ind * type_size),
542
                        ldb,
543
                        to_pointer(args[0], m * k * a_ind * type_size),
544
545
                        lda,
                        &beta_r,
546
                        to_pointer(args[2], m * n * out_ind * type_size),
Shucai Xiao's avatar
Shucai Xiao committed
547
                        ldc);
548
                }
Shucai Xiao's avatar
Shucai Xiao committed
549

550
551
552
            });
        });
    }
553

554
    return args[2];
wsttiger's avatar
wsttiger committed
555
556
557
}

} // namespace gpu
Paul's avatar
Paul committed
558
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
559
} // namespace migraphx