gemm.cpp 8.82 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
3
#include <migraphx/gpu/device/add.hpp>
wsttiger's avatar
wsttiger committed
4

Paul's avatar
Paul committed
5
namespace migraphx {
Paul's avatar
Paul committed
6
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
7
8
namespace gpu {

9
// generic_rocblas_scal: selects the rocBLAS scal entry point from the element
// type tag passed as the first argument and forwards the remaining arguments
// to it unchanged.

// float elements -> rocblas_sscal
template <class... Args>
rocblas_status generic_rocblas_scal(shape::as<float>, Args&&... args)
{
    return rocblas_sscal(std::forward<Args>(args)...);
}

// double elements -> rocblas_dscal
template <class... Args>
rocblas_status generic_rocblas_scal(shape::as<double>, Args&&... args)
{
    return rocblas_dscal(std::forward<Args>(args)...);
}

// Any other element type: no rocBLAS routine is wired up here, so report it
// through MIGRAPHX_THROW. The forwarded arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_scal(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// generic_rocblas_axpy: selects the rocBLAS axpy entry point from the element
// type tag passed as the first argument and forwards the remaining arguments
// to it unchanged.

// half elements -> rocblas_haxpy
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    return rocblas_haxpy(std::forward<Args>(args)...);
}

// float elements -> rocblas_saxpy
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<float>, Args&&... args)
{
    return rocblas_saxpy(std::forward<Args>(args)...);
}

// double elements -> rocblas_daxpy
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    return rocblas_daxpy(std::forward<Args>(args)...);
}

// Any other element type: reported through MIGRAPHX_THROW; the forwarded
// arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_axpy(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// generic_rocblas_dot: selects the rocBLAS dot entry point from the element
// type tag passed as the first argument and forwards the remaining arguments
// to it unchanged.

// float elements -> rocblas_sdot
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    return rocblas_sdot(std::forward<Args>(args)...);
}

// double elements -> rocblas_ddot
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    return rocblas_ddot(std::forward<Args>(args)...);
}

// Any other element type: reported through MIGRAPHX_THROW; the forwarded
// arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_dot(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// generic_rocblas_gemv: selects the rocBLAS gemv entry point from the element
// type tag passed as the first argument and forwards the remaining arguments
// to it unchanged.

// float elements -> rocblas_sgemv
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
    return rocblas_sgemv(std::forward<Ts>(xs)...);
}

// double elements -> rocblas_dgemv
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
    return rocblas_dgemv(std::forward<Ts>(xs)...);
}

// Any other element type: reported through MIGRAPHX_THROW; the forwarded
// arguments are ignored. The message tag matches this function's name, as the
// sibling wrappers' tags match theirs.
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

// generic_rocblas_batched_gemm: selects the rocBLAS strided-batched gemm entry
// point from the element type tag passed as the first argument and forwards
// the remaining arguments to it unchanged.

// float elements -> rocblas_sgemm_strided_batched
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    return rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
}

// double elements -> rocblas_dgemm_strided_batched
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    return rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
}

// half elements -> rocblas_hgemm_strided_batched
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    return rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
}

// Any other element type: reported through MIGRAPHX_THROW; the forwarded
// arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
111
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
112
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
Paul's avatar
Paul committed
113
{
Shucai Xiao's avatar
Shucai Xiao committed
114
    return rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
115
116
}

Paul's avatar
Paul committed
117
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
118
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
Paul's avatar
Paul committed
119
{
Shucai Xiao's avatar
Shucai Xiao committed
120
    return rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
121
122
}

Paul's avatar
Paul committed
123
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
124
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
Paul's avatar
Paul committed
125
{
Shucai Xiao's avatar
Shucai Xiao committed
126
    return rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
127
128
}

Paul's avatar
Paul committed
129
template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
130
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
Paul's avatar
Paul committed
131
{
132
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
133
134
}

Paul's avatar
Paul committed
135
template <class T>
Paul's avatar
Paul committed
136
137
138
139
140
struct compute_rocblas_type
{
    using type = T;
};

Paul's avatar
Paul committed
141
template <class T>
Paul's avatar
Paul committed
142
143
144
145
146
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
147
template <>
Paul's avatar
Paul committed
148
149
150
151
152
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
153
template <class T>
Paul's avatar
Paul committed
154
155
156
157
158
159
160
161
using rb_type = typename compute_rocblas_type<T>::type;

template <class T>
rb_type<T> to_rocblas_type(T x)
{
    return reinterpret_cast<const rb_type<T>&>(x);
}

Paul's avatar
Paul committed
162
template <class T>
Paul's avatar
Paul committed
163
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
164
{
Paul's avatar
Paul committed
165
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
166
167
}

Paul's avatar
Paul committed
168
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
Paul's avatar
Paul committed
169

wsttiger's avatar
wsttiger committed
170
171
// Validates the operand shapes of the GPU dot/gemm operator and returns the
// shape computed by the wrapped operator `op`.
//
// inputs: operand shapes followed by one trailing shape that is stripped
//         before validation and before delegating to `op` (compute() returns
//         its last argument, so this is presumably the output buffer).
// Throws (via MIGRAPHX_THROW) when fewer than three shapes are supplied or
// when, for the first or second operand with more than two dimensions, any
// leading (batch) stride is smaller than one of the last two (matrix)
// strides. check_shapes and op.compute_shape may raise their own errors.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // inputs[0] and inputs[1] are read below and the last entry is dropped, so
    // anything shorter would index out of range or inspect the trailing shape
    // as if it were operand b.
    if(inputs.size() < 3)
    {
        MIGRAPHX_THROW("DOT: expected two operands and an output buffer");
    }

    std::vector<shape> input_shapes(inputs.begin(), inputs.end() - 1);
    check_shapes{input_shapes}.not_broadcasted();

    // Each operand is split at its OWN rank: everything before the last two
    // strides is a batch stride. (Previously b was split using a's rank, which
    // is wrong when the ranks differ and wraps around when a has < 2 dims.)
    auto check_batch_strides = [](const auto& strides, const char* message_prefix) {
        if(strides.size() <= 2)
            return;
        const auto matrix_begin = strides.end() - 2;
        const bool batch_is_outermost =
            std::all_of(strides.begin(), matrix_begin, [&](auto batch_size) {
                return std::all_of(matrix_begin, strides.end(), [&](auto data_size) {
                    return batch_size >= data_size;
                });
            });
        if(!batch_is_outermost)
        {
            MIGRAPHX_THROW(message_prefix + to_string_range(strides) + "} is transposed!");
        }
    };

    check_batch_strides(inputs[0].strides(), "DOT: batch size of a {");
    check_batch_strides(inputs[1].strides(), "DOT: batch size of b {");

    return op.compute_shape(input_shapes);
}
Shucai Xiao's avatar
Shucai Xiao committed
205

wsttiger's avatar
wsttiger committed
206
207
208
// Runs the matrix product on the GPU through rocBLAS and returns the argument
// that holds the result.
//
// ctx:          supplies the HIP stream and the rocBLAS handle.
// output_shape: shape of the result; its element type selects the rocBLAS
//               routine and its lens provide m, n and the batch count.
// args:         {A, B, out} or {A, B, C, out}. With four arguments C is first
//               copied into out and the product is accumulated with
//               beta = op.beta; otherwise beta is 0.
// Returns args[3] in the four-argument form, args[2] otherwise.
// Throws (via MIGRAPHX_THROW) if the copy of C cannot be enqueued, or if the
// element type has no rocBLAS gemm wrapper.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const bool is_3inputs = (args.size() == 4);
    float beta            = 0.0f;
    if(is_3inputs)
    {
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            // Seed the output with C so the gemm below accumulates into it.
            // A failed copy would otherwise silently produce garbage.
            auto copy_status = hipMemcpyAsync(to_pointer(args[3]),
                                              to_pointer(args[2]),
                                              output_shape.bytes(),
                                              hipMemcpyDeviceToDevice,
                                              ctx.get_stream().get());
            if(copy_status != hipSuccess)
            {
                MIGRAPHX_THROW("DOT: failed to copy input C into the output buffer");
            }
        });
        beta = op.beta;
    }

    output_shape.visit_type([&](auto as) {
        const auto& out_lens = output_shape.lens();
        auto n_dim           = out_lens.size();
        auto dim_1           = n_dim - 1;
        auto dim_0           = n_dim - 2;
        auto alpha_r         = to_rocblas_type(as(op.alpha));
        auto beta_r          = to_rocblas_type(as(beta));
        bool transa          = args[0].get_shape().transposed();
        bool transb          = args[1].get_shape().transposed();
        // A transposed operand takes its leading dimension from the last
        // stride instead of the second-to-last one.
        rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
        rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
        rocblas_int ldc = args[2].get_shape().strides()[dim_0];
        rocblas_int m   = out_lens[dim_0];
        rocblas_int n   = out_lens[dim_1];
        rocblas_int k   = args[0].get_shape().lens()[dim_1];
        // Product of every output dimension except the last two.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
        auto out_ptr    = (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));

        // B is passed as the first operand and A as the second, with n and m
        // swapped - presumably so a column-major BLAS yields the row-major
        // product A*B; confirm against the rocBLAS gemm documentation.
        // NOTE(review): the rocblas_status returned by these calls is still
        // discarded, as in the original code.
        if(num_matrices == 1)
        {
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 lda,
                                 &beta_r,
                                 out_ptr,
                                 ldc);
        }
        else
        {
            // Per-matrix strides are k*n, m*k and m*n elements.
            // NOTE(review): this looks like it assumes densely packed
            // batches - TODO confirm.
            generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                out_ptr,
                ldc,
                m * n,
                num_matrices);
        }
    });

    return (is_3inputs ? args[3] : args[2]);
}

} // namespace gpu
Paul's avatar
Paul committed
292
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
293
} // namespace migraphx