"vscode:/vscode.git/clone" did not exist on "d1e0225d23747beb9d4a77a0b8d5cfaf57a74eae"
gemm.cpp 13.1 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
9
rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
10
{
Shucai Xiao's avatar
Shucai Xiao committed
11
    return rocblas_sscal(std::forward<Ts>(xs)...);
12
13
14
}

template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
15
rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
16
{
Shucai Xiao's avatar
Shucai Xiao committed
17
18
19
20
21
22
23
24
25
26
27
28
29
    return rocblas_dscal(std::forward<Ts>(xs)...);
}

// Catch-all overload for type tags without a dedicated scal overload above:
// ignores its arguments and raises through MIGRAPHX_THROW.
template <class Elem, class... Args>
rocblas_status generic_rocblas_scal(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// Overload selected by a half type tag: hands every argument to rocblas_haxpy.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    const rocblas_status status = rocblas_haxpy(std::forward<Args>(args)...);
    return status;
}

template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
33
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
34
{
Shucai Xiao's avatar
Shucai Xiao committed
35
36
37
38
39
40
41
    return rocblas_saxpy(std::forward<Ts>(xs)...);
}

// Overload selected by a double type tag: hands every argument to rocblas_daxpy.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_daxpy(std::forward<Args>(args)...);
    return status;
}

template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// Overload selected by a float type tag: hands every argument to rocblas_sdot.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sdot(std::forward<Args>(args)...);
    return status;
}

// Overload selected by a double type tag: hands every argument to rocblas_ddot.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_ddot(std::forward<Args>(args)...);
    return status;
}

// Catch-all overload for type tags without a dedicated dot overload above:
// ignores its arguments and raises through MIGRAPHX_THROW.
template <class Elem, class... Args>
rocblas_status generic_rocblas_dot(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// Overload selected by a float type tag: hands every argument to rocblas_sgemv.
template <class... Args>
rocblas_status generic_rocblas_gemv(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sgemv(std::forward<Args>(args)...);
    return status;
}

// Overload selected by a double type tag: hands every argument to rocblas_dgemv.
template <class... Args>
rocblas_status generic_rocblas_gemv(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_dgemv(std::forward<Args>(args)...);
    return status;
}

// Catch-all overload for type tags without a dedicated gemv overload above
// (only float and double have one): ignores its arguments and raises through
// MIGRAPHX_THROW. The message prefix matches the function name, like the
// sibling scal/axpy/dot/gemm fallbacks.
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

// Overload selected by a float type tag: hands every argument to
// rocblas_sgemm_strided_batched.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

// Overload selected by a double type tag: hands every argument to
// rocblas_dgemm_strided_batched.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

// Overload selected by a half type tag: hands every argument to
// rocblas_hgemm_strided_batched.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    const rocblas_status status = rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

// Catch-all overload for type tags without a dedicated strided-batched gemm
// overload above: ignores its arguments and raises through MIGRAPHX_THROW.
template <class Elem, class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<Elem>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
111
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
Paul's avatar
Paul committed
112
{
Shucai Xiao's avatar
Shucai Xiao committed
113
    return rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
117
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
Paul's avatar
Paul committed
118
{
Shucai Xiao's avatar
Shucai Xiao committed
119
    return rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
123
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
Paul's avatar
Paul committed
124
{
Shucai Xiao's avatar
Shucai Xiao committed
125
    return rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
129
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
Paul's avatar
Paul committed
130
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
// Trait mapping an element type to the type used when calling rocblas.
// The primary template is the identity mapping; specializations override it.
template <class Elem>
struct compute_rocblas_type
{
    typedef Elem type;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Reinterprets the bits of a value as its rocblas counterpart (rb_type) and
// returns the result by value.
template <class Elem>
rb_type<Elem> to_rocblas_type(Elem x)
{
    using result_ref = const rb_type<Elem>&;
    return reinterpret_cast<result_ref>(x);
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Non-template overload for half values: reinterprets the bits as rocblas_half.
rocblas_half to_rocblas_type(half x)
{
    const rocblas_half& as_rocblas = reinterpret_cast<const rocblas_half&>(x);
    return as_rocblas;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Computes the output shape of the GPU gemm.
//
// The last entry of `inputs` is dropped before delegating to the wrapped
// operator's compute_shape (compute() uses the last argument as the buffer it
// writes the result to). When three operand shapes remain, the third one
// (inputs[2]) is additionally checked with not_broadcasted().
// Raises through MIGRAPHX_THROW when `inputs` is empty.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // Guard first: stepping back from begin() of an empty vector is undefined.
    if(inputs.empty())
    {
        MIGRAPHX_THROW("MIOPEN_GEMM: expected at least one input shape");
    }

    std::vector<shape> input_shapes(inputs.begin(), inputs.end() - 1);
    if(input_shapes.size() == 3)
    {
        auto c_shape = inputs[2];
        check_shapes{{c_shape}}.not_broadcasted();
    }
    return op.compute_shape(input_shapes);
}
Shucai Xiao's avatar
Shucai Xiao committed
179

wsttiger's avatar
wsttiger committed
180
181
182
// Executes the dot/GEMM operator on the GPU through rocBLAS.
//
// `args` holds the operand buffers followed by the buffer receiving the result:
//   * four entries (A, B, C, result): C is first copied into the result buffer,
//     which is then used as the in/out matrix of a strided-batched GEMM called
//     with op.alpha and op.beta;
//   * otherwise (A, B, result): beta is fixed to 0 and the call is selected from
//     the operand shapes (inner product, matrix * vector, vector * matrix, or
//     strided-batched GEMM).
// Returns the argument holding the result (args[3] or args[2] respectively).
// Raises through MIGRAPHX_THROW when the device copy or a rocBLAS call reports
// a failure, or when the output type has no matching rocBLAS overload.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const auto alpha = op.alpha;

    // Surface rocBLAS failures instead of silently returning an unwritten buffer.
    auto check_status = [](rocblas_status status) {
        if(status != rocblas_status_success)
        {
            MIGRAPHX_THROW("MIOPEN_GEMM: rocblas call failed");
        }
    };

    // Strided-batched GEMM over the trailing two dimensions, writing into `result`.
    // rocBLAS is column-major, so the row-major product C = A * B is evaluated as
    // C^T = B^T * A^T: B is passed first, A second, and m/n are swapped.
    auto run_batched_gemm = [&](const argument& result, auto beta) {
        output_shape.visit_type([&](auto as) {
            const auto& out_lens = output_shape.lens();
            const auto n_dim     = out_lens.size();
            const auto dim_1     = n_dim - 1;
            const auto dim_0     = n_dim - 2;
            auto alpha_r         = to_rocblas_type(as(alpha));
            auto beta_r          = to_rocblas_type(as(beta));
            const bool transa    = args[0].get_shape().transposed();
            const bool transb    = args[1].get_shape().transposed();
            rocblas_int lda      = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
            rocblas_int ldb      = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
            rocblas_int ldc      = result.get_shape().strides()[dim_0];
            rocblas_int m        = out_lens[dim_0];
            rocblas_int n        = out_lens[dim_1];
            rocblas_int k        = args[0].get_shape().lens()[dim_1];
            // Every dimension in front of the last two contributes to the batch count.
            auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
                                                out_lens.rend(),
                                                std::size_t{1},
                                                std::multiplies<std::size_t>());
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            check_status(generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                to_pointer(result),
                ldc,
                m * n,
                num_matrices));
        });
    };

    // A, B, C plus the result buffer.
    if(args.size() == 4)
    {
        // Seed the result with C so the GEMM can use it as its in/out matrix.
        // NOTE(review): hipMemcpy is not issued on the context's stream while the
        // GEMM below is - confirm the ordering against work queued on that stream.
        if(hipMemcpy(args[3].data(),
                     args[2].data(),
                     output_shape.bytes(),
                     hipMemcpyDeviceToDevice) != hipSuccess)
        {
            MIGRAPHX_THROW("MIOPEN_GEMM: failed to copy the C operand into the output buffer");
        }

        run_batched_gemm(args[3], op.beta);
        return args[3];
    }

    // 2 input argument cases
    auto a_lens = args[0].get_shape().lens();
    auto b_lens = args[1].get_shape().lens();
    // vector inner product
    if(output_shape.elements() == 1)
    {
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
        float beta           = 0.0f;
        rocblas_int elem_num = static_cast<rocblas_int>(args[0].get_shape().elements());
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            // the function generic_rocblas_dot is not stable, so have to
            // call the gemm function instead. In the future, we may change
            // to call generic_rocblas_dot when it is stable.
            check_status(generic_rocblas_gemm(as,
                                              ctx.get_stream().get_rocblas(),
                                              rocblas_operation_none,
                                              rocblas_operation_none,
                                              1,
                                              1,
                                              elem_num,
                                              &alpha_r,
                                              to_pointer(args[1]),
                                              1,
                                              to_pointer(args[0]),
                                              elem_num,
                                              &beta_r,
                                              to_pointer(args[2]),
                                              1));
        });
    }
    // matrix * vector (b is a vector)
    else if(b_lens.size() == 2 && b_lens.at(1) == 1)
    {
        const auto& a_shape = args[0].get_shape();
        const bool transa   = a_shape.transposed();
        const auto rows     = static_cast<rocblas_int>(a_lens[0]);
        const auto cols     = static_cast<rocblas_int>(a_lens[1]);
        // rocBLAS reads the buffer column-major. A non-transposed (row-major)
        // rows x cols matrix is therefore seen as a cols x rows matrix holding A^T,
        // which has to be transposed back; a transposed shape is already stored
        // the way rocBLAS expects. The leading dimension is the larger stride,
        // the same choice the batched GEMM path makes.
        rocblas_int m   = transa ? rows : cols;
        rocblas_int n   = transa ? cols : rows;
        rocblas_int lda = a_shape.strides()[transa ? 1 : 0];
        float beta      = 0.0f;
        assert(a_lens.back() == args[1].get_shape().elements());

        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            check_status(
                generic_rocblas_gemv(as,
                                     ctx.get_stream().get_rocblas(),
                                     transa ? rocblas_operation_none : rocblas_operation_transpose,
                                     m,
                                     n,
                                     &alpha_r,
                                     to_pointer(args[0]),
                                     lda,
                                     to_pointer(args[1]),
                                     1,
                                     &beta_r,
                                     to_pointer(args[2]),
                                     1));
        });
    }
    // vector * matrix (a is a vector)
    else if(a_lens.size() == 2 && a_lens.at(0) == 1)
    {
        const auto& b_shape = args[1].get_shape();
        const bool transb   = b_shape.transposed();
        const auto rows     = static_cast<rocblas_int>(b_lens[0]);
        const auto cols     = static_cast<rocblas_int>(b_lens[1]);
        // y = x * B is evaluated as y^T = B^T * x^T. A non-transposed (row-major) B
        // is seen by rocBLAS as the cols x rows matrix B^T, so no extra transpose is
        // needed; a transposed shape is seen as B itself and must be transposed.
        // The leading dimension is the larger stride (it was always 1 before, which
        // rocBLAS rejects for any matrix with more than one row).
        rocblas_int m   = transb ? rows : cols;
        rocblas_int n   = transb ? cols : rows;
        rocblas_int ldb = b_shape.strides()[transb ? 1 : 0];
        float beta      = 0.0f;
        assert(b_lens[0] == args[0].get_shape().elements());

        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            check_status(
                generic_rocblas_gemv(as,
                                     ctx.get_stream().get_rocblas(),
                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
                                     m,
                                     n,
                                     &alpha_r,
                                     to_pointer(args[1]),
                                     ldb,
                                     to_pointer(args[0]),
                                     1,
                                     &beta_r,
                                     to_pointer(args[2]),
                                     1));
        });
    }
    // batch matrix multiplication
    else
    {
        run_batched_gemm(args[2], 0.0f);
    }

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
382
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
383
} // namespace migraphx