gemm.cpp 13.2 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
3
#include <migraphx/gpu/device/add.hpp>
wsttiger's avatar
wsttiger committed
4

Paul's avatar
Paul committed
5
namespace migraphx {
Paul's avatar
Paul committed
6
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
7
8
namespace gpu {

9
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
10
rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
11
{
Shucai Xiao's avatar
Shucai Xiao committed
12
    return rocblas_sscal(std::forward<Ts>(xs)...);
13
14
15
}

template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
16
rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
17
{
Shucai Xiao's avatar
Shucai Xiao committed
18
19
20
21
22
23
24
25
26
27
28
29
30
    return rocblas_dscal(std::forward<Ts>(xs)...);
}

/// Catch-all scal overload for element types without a rocBLAS scal routine.
/// Selected only when no more specialized overload matches; always throws.
template <class Elem, class... Ignored>
rocblas_status generic_rocblas_scal(shape::as<Elem>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

/// Half-precision axpy: forwards every argument unchanged to rocblas_haxpy.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    const rocblas_status status = rocblas_haxpy(std::forward<Args>(args)...);
    return status;
}

template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
34
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
35
{
Shucai Xiao's avatar
Shucai Xiao committed
36
37
38
39
40
41
42
    return rocblas_saxpy(std::forward<Ts>(xs)...);
}

/// Double-precision axpy: forwards every argument unchanged to rocblas_daxpy.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_daxpy(std::forward<Args>(args)...);
    return status;
}

template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

/// Single-precision dot product: forwards every argument unchanged to rocblas_sdot.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sdot(std::forward<Args>(args)...);
    return status;
}

/// Double-precision dot product: forwards every argument unchanged to rocblas_ddot.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_ddot(std::forward<Args>(args)...);
    return status;
}

/// Catch-all dot overload for element types without a rocBLAS dot routine.
/// Selected only when no more specialized overload matches; always throws.
template <class Elem, class... Ignored>
rocblas_status generic_rocblas_dot(shape::as<Elem>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

/// Single-precision matrix-vector product: forwards every argument unchanged to rocblas_sgemv.
template <class... Args>
rocblas_status generic_rocblas_gemv(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sgemv(std::forward<Args>(args)...);
    return status;
}

/// Double-precision matrix-vector product: forwards every argument unchanged to rocblas_dgemv.
template <class... Args>
rocblas_status generic_rocblas_gemv(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_dgemv(std::forward<Args>(args)...);
    return status;
}

/// Catch-all gemv overload for element types without a rocBLAS gemv routine.
/// Selected only when no more specialized overload matches; always throws.
/// The message names the routine correctly as GEMV (it previously read "GEMMV").
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

/// Single-precision strided-batched GEMM: forwards every argument unchanged to
/// rocblas_sgemm_strided_batched.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

/// Double-precision strided-batched GEMM: forwards every argument unchanged to
/// rocblas_dgemm_strided_batched.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

/// Half-precision strided-batched GEMM: forwards every argument unchanged to
/// rocblas_hgemm_strided_batched.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    const rocblas_status status = rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

/// Catch-all strided-batched GEMM overload for element types rocBLAS does not support.
/// Selected only when no more specialized overload matches; always throws.
template <class Elem, class... Ignored>
rocblas_status generic_rocblas_batched_gemm(shape::as<Elem>, Ignored&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
111
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
112
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
Paul's avatar
Paul committed
113
{
Shucai Xiao's avatar
Shucai Xiao committed
114
    return rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
115
116
}

Paul's avatar
Paul committed
117
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
118
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
Paul's avatar
Paul committed
119
{
Shucai Xiao's avatar
Shucai Xiao committed
120
    return rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
121
122
}

Paul's avatar
Paul committed
123
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
124
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
Paul's avatar
Paul committed
125
{
Shucai Xiao's avatar
Shucai Xiao committed
126
    return rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
127
128
}

Paul's avatar
Paul committed
129
template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
130
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
Paul's avatar
Paul committed
131
{
132
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
133
134
}

Paul's avatar
Paul committed
135
/// Maps an element type to the type handed to rocBLAS.
/// By default the type is used as-is.
template <class Elem>
struct compute_rocblas_type
{
    using type = Elem;
};

/// A const-qualified element type maps to the const-qualified mapping of the
/// unqualified type, so const-ness is preserved through the conversion.
template <class Elem>
struct compute_rocblas_type<const Elem>
{
    using type = typename compute_rocblas_type<Elem>::type const;
};

Paul's avatar
Paul committed
147
template <>
Paul's avatar
Paul committed
148
149
150
151
152
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
153
template <class T>
Paul's avatar
Paul committed
154
155
156
157
158
159
160
161
using rb_type = typename compute_rocblas_type<T>::type;

/// Reinterprets a scalar value as its rocBLAS-facing type (see rb_type) and
/// returns a copy of it.
template <class From>
rb_type<From> to_rocblas_type(From value)
{
    const auto& as_rocblas = reinterpret_cast<const rb_type<From>&>(value);
    return as_rocblas;
}

Paul's avatar
Paul committed
162
template <class T>
Paul's avatar
Paul committed
163
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
164
{
Paul's avatar
Paul committed
165
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
166
167
}

Paul's avatar
Paul committed
168
/// Non-template overload for half: reinterprets the value's bits as rocblas_half.
rocblas_half to_rocblas_type(half x)
{
    const auto& bits = reinterpret_cast<const rocblas_half&>(x);
    return bits;
}
Paul's avatar
Paul committed
169

wsttiger's avatar
wsttiger committed
170
171
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // The trailing input is the output allocation; only the operands before it
    // take part in shape inference.
    std::vector<shape> operand_shapes(inputs.begin(), inputs.end() - 1);

    // With three operands (A, B, C), the C operand is additionally required
    // to be non-broadcasted.
    if(operand_shapes.size() == 3)
    {
        auto c_shape = operand_shapes[2];
        check_shapes{{c_shape}}.not_broadcasted();
    }

    return op.compute_shape(operand_shapes);
}
Shucai Xiao's avatar
Shucai Xiao committed
180

wsttiger's avatar
wsttiger committed
181
182
183
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    // args is {A, B, result} for the two-input form and {A, B, C, result} for the
    // three-input form; the last element is the pre-allocated output buffer.
    bool is_3inputs = (args.size() == 4);

    // General strided-batched GEMM: result = alpha * A * B + beta * result.
    // rocBLAS is column-major while these shapes are row-major, so C = A * B is
    // evaluated as C^T = B^T * A^T: B is passed first and m/n are exchanged.
    // The per-matrix strides (m*k, k*n, m*n) assume each batch entry is packed.
    auto run_batched_gemm = [&](const argument& result, auto beta) {
        output_shape.visit_type([&](auto as) {
            auto n_dim        = output_shape.lens().size();
            auto dim_1        = n_dim - 1;
            auto dim_0        = n_dim - 2;
            auto alpha_r      = to_rocblas_type(as(op.alpha));
            auto beta_r       = to_rocblas_type(as(beta));
            bool transa       = args[0].get_shape().transposed();
            bool transb       = args[1].get_shape().transposed();
            rocblas_int lda   = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
            rocblas_int ldb   = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
            rocblas_int ldc   = result.get_shape().strides()[dim_0];
            auto out_lens     = output_shape.lens();
            rocblas_int m     = out_lens[dim_0];
            rocblas_int n     = out_lens[dim_1];
            rocblas_int k     = args[0].get_shape().lens()[dim_1];
            auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
                                                out_lens.rend(),
                                                std::size_t{1},
                                                std::multiplies<std::size_t>());
            auto to_pointer   = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                to_pointer(result),
                ldc,
                m * n,
                num_matrices);
        });
    };

    if(is_3inputs)
    {
        // Seed the output with C so rocBLAS can accumulate beta * C in place.
        // NOTE(review): this raw byte copy assumes C has the same layout as the output.
        auto copy_status = hipMemcpyAsync(args[3].data(),
                                          args[2].data(),
                                          output_shape.bytes(),
                                          hipMemcpyDeviceToDevice,
                                          ctx.get_stream().get());
        if(copy_status != hipSuccess)
            MIGRAPHX_THROW("MIOPEN_GEMM: failed to copy C into the output buffer");

        run_batched_gemm(args[3], op.beta);
        return args[3];
    }

    // 2 input argument cases
    auto a_lens = args[0].get_shape().lens();
    auto b_lens = args[1].get_shape().lens();
    // vector inner product
    if(output_shape.elements() == 1)
    {
        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
        float beta           = 0.0f;
        rocblas_int elem_num = static_cast<rocblas_int>(args[0].get_shape().elements());
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            // the function generic_rocblas_dot is not stable, so have to
            // call the gemm function instead. In the future, we may change
            // to call generic_rocblas_dot when it is stable.
            // Computes the 1x1 product (1 x elem_num) * (elem_num x 1).
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 rocblas_operation_none,
                                 rocblas_operation_none,
                                 1,
                                 1,
                                 elem_num,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 1,
                                 to_pointer(args[0]),
                                 elem_num,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 1);
        });
    }
    // matrix * vector (b is a vector): y = A * x, A is m x n
    else if(b_lens.size() == 2 && b_lens.at(1) == 1)
    {
        // Viewed column-major, a standard (row-major) A is stored as A^T (n x m, ld = n),
        // so rocBLAS must transpose it back; a transposed A is stored as A itself
        // (m x n, ld = m). rocBLAS's m/n describe the matrix as stored.
        bool transa     = args[0].get_shape().transposed();
        rocblas_int m   = static_cast<rocblas_int>(a_lens[0]);
        rocblas_int n   = static_cast<rocblas_int>(a_lens[1]);
        rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
        float beta      = 0.0f;
        assert(a_lens.back() == args[1].get_shape().elements());

        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_gemv(as,
                                 ctx.get_stream().get_rocblas(),
                                 transa ? rocblas_operation_none : rocblas_operation_transpose,
                                 transa ? m : n,
                                 transa ? n : m,
                                 &alpha_r,
                                 to_pointer(args[0]),
                                 lda,
                                 to_pointer(args[1]),
                                 1,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 1);
        });
    }
    // vector * matrix (a is a vector): y = B^T * a, B is m x n
    else if(a_lens.size() == 2 && a_lens.at(0) == 1)
    {
        // Viewed column-major, a standard B is stored as B^T (n x m, ld = n), which is
        // already the operator needed; a transposed B is stored as B (m x n, ld = m)
        // and must be transposed by rocBLAS.
        bool transb     = args[1].get_shape().transposed();
        rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
        rocblas_int m   = b_lens[0];
        rocblas_int n   = b_lens[1];
        float beta      = 0.0f;
        assert(b_lens[0] == args[0].get_shape().elements());
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            generic_rocblas_gemv(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transb ? m : n,
                                 transb ? n : m,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 1,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 1);
        });
    }
    // batch matrix multiplication
    else
    {
        run_batched_gemm(args[2], 0.0f);
    }

    return args[2];
}

} // namespace gpu
Paul's avatar
Paul committed
385
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
386
} // namespace migraphx