gemm.cpp 9.79 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
wsttiger's avatar
wsttiger committed
3

Paul's avatar
Paul committed
4
namespace migraphx {
Paul's avatar
Paul committed
5
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
6
7
namespace gpu {

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// float overload: the leading shape::as<float> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_sscal. Whatever rocblas_sscal
// returns is not inspected.
template <class... Args>
void generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    rocblas_sscal(std::forward<Args>(args)...);
}

// double overload: the leading shape::as<double> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_dscal. Whatever rocblas_dscal
// returns is not inspected.
template <class... Args>
void generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    rocblas_dscal(std::forward<Args>(args)...);
}

// Catch-all overload for any other element type T: ignores its arguments and reports
// the unsupported type through MIGRAPHX_THROW.
template <class T, class... Args>
void generic_rocblas_scal(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// half overload: the leading shape::as<half> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_haxpy. Whatever rocblas_haxpy
// returns is not inspected.
template <class... Args>
void generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    rocblas_haxpy(std::forward<Args>(args)...);
}

// float overload: the leading shape::as<float> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_saxpy. Whatever rocblas_saxpy
// returns is not inspected.
template <class... Args>
void generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    rocblas_saxpy(std::forward<Args>(args)...);
}

// double overload: the leading shape::as<double> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_daxpy. Whatever rocblas_daxpy
// returns is not inspected.
template <class... Args>
void generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    rocblas_daxpy(std::forward<Args>(args)...);
}

// Catch-all overload for any other element type T: ignores its arguments and reports
// the unsupported type through MIGRAPHX_THROW.
template <class T, class... Args>
void generic_rocblas_axpy(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// float overload: the leading shape::as<float> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_sdot. Whatever rocblas_sdot
// returns is not inspected.
template <class... Args>
void generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    rocblas_sdot(std::forward<Args>(args)...);
}

// double overload: the leading shape::as<double> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_ddot. Whatever rocblas_ddot
// returns is not inspected.
template <class... Args>
void generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    rocblas_ddot(std::forward<Args>(args)...);
}

// Catch-all overload for any other element type T: ignores its arguments and reports
// the unsupported type through MIGRAPHX_THROW.
template <class T, class... Args>
void generic_rocblas_dot(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// float overload: the leading shape::as<float> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_sgemv. Whatever rocblas_sgemv
// returns is not inspected.
template <class... Args>
void generic_rocblas_gemv(shape::as<float>, Args&&... args)
{
    rocblas_sgemv(std::forward<Args>(args)...);
}

// double overload: the leading shape::as<double> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_dgemv. Whatever rocblas_dgemv
// returns is not inspected.
template <class... Args>
void generic_rocblas_gemv(shape::as<double>, Args&&... args)
{
    rocblas_dgemv(std::forward<Args>(args)...);
}

// Catch-all overload for any other element type T: ignores its arguments and reports
// the unsupported type through MIGRAPHX_THROW. The message tag is spelled after the
// function name (GEMV) so it can be matched against the other generic_rocblas_* errors.
template <class T, class... Ts>
void generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// float overload: the leading shape::as<float> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_sgemm_strided_batched, whose
// return value is not inspected.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// double overload: the leading shape::as<double> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_dgemm_strided_batched, whose
// return value is not inspected.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// half overload: the leading shape::as<half> tag only selects this overload; every
// remaining argument is perfectly forwarded to rocblas_hgemm_strided_batched, whose
// return value is not inspected.
template <class... Args>
void generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Catch-all overload for any other element type T: ignores its arguments and reports
// the unsupported type through MIGRAPHX_THROW.
template <class T, class... Args>
void generic_rocblas_batched_gemm(shape::as<T>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
110
template <class... Ts>
Paul's avatar
Paul committed
111
112
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
Paul's avatar
Paul committed
113
    rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
114
115
}

Paul's avatar
Paul committed
116
template <class... Ts>
Paul's avatar
Paul committed
117
118
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
Paul's avatar
Paul committed
119
    rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
120
121
}

Paul's avatar
Paul committed
122
template <class... Ts>
Paul's avatar
Paul committed
123
124
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
Paul's avatar
Paul committed
125
    rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
126
127
}

Paul's avatar
Paul committed
128
template <class T, class... Ts>
Paul's avatar
Paul committed
129
130
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
131
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
132
133
}

Paul's avatar
Paul committed
134
// Type trait naming the type that stands in for T when talking to rocBLAS.
// Primary template: T is used unchanged.
template <class T>
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
140
template <class T>
Paul's avatar
Paul committed
141
142
143
144
145
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
146
template <>
Paul's avatar
Paul committed
147
148
149
150
151
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
152
template <class T>
Paul's avatar
Paul committed
153
154
155
156
157
158
159
160
using rb_type = typename compute_rocblas_type<T>::type;

// Returns a copy of x viewed as rb_type<T>. The bits of x are reinterpreted through a
// reference cast; no value conversion is performed.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
    using target_type = rb_type<T>;
    const target_type& reinterpreted = reinterpret_cast<const target_type&>(x);
    return reinterpreted;
}

Paul's avatar
Paul committed
161
template <class T>
Paul's avatar
Paul committed
162
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
163
{
Paul's avatar
Paul committed
164
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
165
166
}

Paul's avatar
Paul committed
167
// Non-template overload for half: returns a copy of x with its bits reinterpreted as
// rocblas_half; no value conversion is performed.
rocblas_half to_rocblas_type(half x)
{
    const rocblas_half& reinterpreted = reinterpret_cast<const rocblas_half&>(x);
    return reinterpreted;
}
Paul's avatar
Paul committed
168

wsttiger's avatar
wsttiger committed
169
170
// Returns whatever op.compute_shape produces for the given input shapes; the inputs
// are passed through unchanged.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    return op.compute_shape(inputs);
}
173
174
175
176
177
178
179

std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens, 
    std::size_t index, std::vector<std::size_t> &data_lens) const
{
    
}

wsttiger's avatar
wsttiger committed
180
181
182
// Runs the dot/GEMM operation through rocBLAS and returns the argument that holds the
// result.
//
// Arguments, as read by the code below:
//   args[0], args[1] - the two operands (A and B)
//   args[2]          - the output buffer when args.size() != 4, otherwise the third
//                      input C
//   args[3]          - the output buffer when args.size() == 4
//
// Three paths are taken depending on the shapes:
//   1. output_shape has exactly one element: dot(A, B), scaled by alpha, plus beta * C
//      when C is present.
//   2. B has a single dimension: one gemv per matrix in the (possibly batched) A.
//   3. otherwise: one strided-batched gemm over all matrices.
//
// In paths 2 and 3 the output buffer is first filled with C (three inputs) or with zero
// bytes (two inputs), and rocBLAS then accumulates into it using op.beta.
//
// NOTE(review): return codes of the hip* and rocblas_* calls are not checked here.
// NOTE(review): path 1 gives rocBLAS the output buffer's pointer for the dot result
// together with addresses of local variables for alpha/beta - confirm the handle's
// pointer mode accepts that combination.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const bool is_3inputs  = (args.size() == 4);
    const argument& result = is_3inputs ? args[3] : args[2];

    // Path 1: inner product. Both operands are read with unit stride for
    // args[1].get_shape().elements() entries.
    if(output_shape.elements() == 1)
    {
        output_shape.visit_type([&](auto as) {
            auto alpha_r    = to_rocblas_type(as(op.alpha));
            auto beta_r     = to_rocblas_type(as(op.beta));
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

            // result = dot(A, B)
            generic_rocblas_dot(as,
                                ctx.get_stream().get_rocblas(),
                                args[1].get_shape().elements(),
                                to_pointer(args[0]),
                                1,
                                to_pointer(args[1]),
                                1,
                                to_pointer(result));

            // result *= alpha (one element, unit increment)
            generic_rocblas_scal(
                as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(result), 1);

            // result += beta * C
            if(is_3inputs)
            {
                generic_rocblas_axpy(as,
                                     ctx.get_stream().get_rocblas(),
                                     1,
                                     &beta_r,
                                     to_pointer(args[2]),
                                     1,
                                     to_pointer(args[3]),
                                     1);
            }
        });

        return result;
    }

    // Path 2: B is a vector, so the computation is matrix * vector. It cannot be an
    // inner product of two vectors because that case was handled above. A may be a
    // batch of matrices; the same vector is applied to every matrix.
    if(args[1].get_shape().lens().size() == 1)
    {
        const auto a_lens        = args[0].get_shape().lens();
        const std::size_t n_dims = a_lens.size();
        const std::size_t dim_0  = n_dims - 2;
        const std::size_t dim_1  = n_dims - 1;
        const bool transa        = args[0].get_shape().transposed();
        const rocblas_int lda    = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
        const rocblas_int m      = a_lens[dim_0];
        const rocblas_int k      = a_lens[dim_1];
        const auto batch_num     = std::accumulate(
            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

        // Element distance between consecutive matrices of A / consecutive output vectors.
        // assumes the batches of A are stored back to back - TODO confirm
        const std::size_t a_batch_stride = static_cast<std::size_t>(m) * k;
        const std::size_t y_batch_stride = static_cast<std::size_t>(m);

        output_shape.visit_type([&](auto as) {
            auto alpha_r = to_rocblas_type(as(op.alpha));
            auto beta_r  = to_rocblas_type(as(op.beta));
            // offset is counted in elements, so it is applied after the cast to the
            // element type rather than to the raw data() pointer.
            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
                return to_rocblas_type(as.from(arg.data()) + offset);
            };

            // Initialise the whole output once, before any batch is computed.
            if(is_3inputs)
                hipMemcpy(to_pointer(args[3]),
                          to_pointer(args[2]),
                          output_shape.bytes(),
                          hipMemcpyDeviceToDevice);
            else
                hipMemset(to_pointer(args[2]), 0, output_shape.bytes());

            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
            {
                // rocBLAS works on column-major storage. A row-major m x k matrix is, to
                // rocBLAS, a k x m matrix that must be transposed; a matrix whose shape
                // reports transposed() is already m x k in that view and is used as is.
                // assumes the vector in args[1] has unit stride - TODO confirm
                generic_rocblas_gemv(as,
                                     ctx.get_stream().get_rocblas(),
                                     transa ? rocblas_operation_none
                                            : rocblas_operation_transpose,
                                     transa ? m : k,
                                     transa ? k : m,
                                     &alpha_r,
                                     to_pointer(args[0], batch_no * a_batch_stride),
                                     lda,
                                     to_pointer(args[1]),
                                     1,
                                     &beta_r,
                                     to_pointer(result, batch_no * y_batch_stride),
                                     1);
            }
        });

        return result;
    }

    // Path 3: (batched) matrix * matrix.
    const bool transa        = args[0].get_shape().transposed();
    const bool transb        = args[1].get_shape().transposed();
    const std::size_t n_dims = args[0].get_shape().lens().size();
    const std::size_t dim_0  = n_dims - 2;
    const std::size_t dim_1  = n_dims - 1;
    const rocblas_int lda    = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    const rocblas_int ldb    = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    const rocblas_int ldc    = args[2].get_shape().strides()[dim_0];
    const auto out_lens      = output_shape.lens();
    const rocblas_int m      = out_lens[dim_0];
    const rocblas_int n      = out_lens[dim_1];
    const rocblas_int k      = args[0].get_shape().lens()[dim_1];
    const auto batch_num     = std::accumulate(
        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

    output_shape.visit_type([&](auto as) {
        auto alpha_r    = to_rocblas_type(as(op.alpha));
        auto beta_r     = to_rocblas_type(as(op.beta));
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

        // Seed the output with C (three inputs) or with zero bytes (two inputs).
        if(is_3inputs)
            hipMemcpy(to_pointer(args[3]),
                      to_pointer(args[2]),
                      output_shape.bytes(),
                      hipMemcpyDeviceToDevice);
        else
            hipMemset(to_pointer(args[2]), 0, output_shape.bytes());

        // The operands are passed in swapped order (B first, then A, with n before m):
        // in rocBLAS's column-major view this computes the transposed product, whose
        // memory layout is the row-major result. The product is accumulated into the
        // buffer that is returned, not into the C input.
        generic_rocblas_batched_gemm(as,
                                     ctx.get_stream().get_rocblas(),
                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
                                     n,
                                     m,
                                     k,
                                     &alpha_r,
                                     to_pointer(args[1]),
                                     ldb,
                                     k * n,
                                     to_pointer(args[0]),
                                     lda,
                                     m * k,
                                     &beta_r,
                                     to_pointer(result),
                                     ldc,
                                     m * n,
                                     batch_num);
    });

    return result;
}

} // namespace gpu
Paul's avatar
Paul committed
316
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
317
} // namespace migraphx