gemm.cpp 8.75 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/gpu/gemm.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/gpu/context.hpp>
3
#include <migraphx/gpu/device/add.hpp>
wsttiger's avatar
wsttiger committed
4

Paul's avatar
Paul committed
5
namespace migraphx {
Paul's avatar
Paul committed
6
inline namespace MIGRAPHX_INLINE_NS {
wsttiger's avatar
wsttiger committed
7
8
namespace gpu {

9
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
10
rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs)
11
{
Shucai Xiao's avatar
Shucai Xiao committed
12
    return rocblas_sscal(std::forward<Ts>(xs)...);
13
14
15
}

template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
16
rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs)
17
{
Shucai Xiao's avatar
Shucai Xiao committed
18
19
20
21
22
23
24
25
26
27
28
29
30
    return rocblas_dscal(std::forward<Ts>(xs)...);
}

// SCAL dispatcher, catch-all tag: accepts any shape::as<Element> and always
// throws; the arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_scal(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}

// AXPY dispatcher, half tag: hands every argument after the type tag to
// rocblas_haxpy unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<half>, Args&&... args)
{
    const rocblas_status status = rocblas_haxpy(std::forward<Args>(args)...);
    return status;
}

template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
34
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
35
{
Shucai Xiao's avatar
Shucai Xiao committed
36
37
38
39
40
41
42
    return rocblas_saxpy(std::forward<Ts>(xs)...);
}

// AXPY dispatcher, double tag: hands every argument after the type tag to
// rocblas_daxpy unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_axpy(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_daxpy(std::forward<Args>(args)...);
    return status;
}

template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}

// DOT dispatcher, float tag: hands every argument after the type tag to
// rocblas_sdot unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sdot(std::forward<Args>(args)...);
    return status;
}

// DOT dispatcher, double tag: hands every argument after the type tag to
// rocblas_ddot unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_dot(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_ddot(std::forward<Args>(args)...);
    return status;
}

// DOT dispatcher, catch-all tag: accepts any shape::as<Element> and always
// throws; the arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_dot(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}

// GEMV dispatcher, float tag: hands every argument after the type tag to
// rocblas_sgemv unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_gemv(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sgemv(std::forward<Args>(args)...);
    return status;
}

// GEMV dispatcher, double tag: hands every argument after the type tag to
// rocblas_dgemv unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_gemv(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_dgemv(std::forward<Args>(args)...);
    return status;
}

// GEMV dispatcher, catch-all tag: accepts any shape::as<T> and always throws;
// the arguments are ignored.
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
    // The message names this function (GEMV) so a failure can be traced back
    // to the right dispatcher.
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMV: type unsupported by rocblas");
}

// Batched-GEMM dispatcher, float tag: hands every argument after the type tag
// to rocblas_sgemm_strided_batched unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Args&&... args)
{
    const rocblas_status status = rocblas_sgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

// Batched-GEMM dispatcher, double tag: hands every argument after the type tag
// to rocblas_dgemm_strided_batched unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Args&&... args)
{
    const rocblas_status status = rocblas_dgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

// Batched-GEMM dispatcher, half tag: hands every argument after the type tag
// to rocblas_hgemm_strided_batched unchanged and returns the status it reports.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Args&&... args)
{
    const rocblas_status status = rocblas_hgemm_strided_batched(std::forward<Args>(args)...);
    return status;
}

// Batched-GEMM dispatcher, catch-all tag: accepts any shape::as<Element> and
// always throws; the arguments are ignored.
template <class Element, class... Args>
rocblas_status generic_rocblas_batched_gemm(shape::as<Element>, Args&&...)
{
    MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}

Paul's avatar
Paul committed
111
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
112
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
Paul's avatar
Paul committed
113
{
Shucai Xiao's avatar
Shucai Xiao committed
114
    return rocblas_sgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
115
116
}

Paul's avatar
Paul committed
117
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
118
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
Paul's avatar
Paul committed
119
{
Shucai Xiao's avatar
Shucai Xiao committed
120
    return rocblas_dgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
121
122
}

Paul's avatar
Paul committed
123
template <class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
124
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
Paul's avatar
Paul committed
125
{
Shucai Xiao's avatar
Shucai Xiao committed
126
    return rocblas_hgemm(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
127
128
}

Paul's avatar
Paul committed
129
template <class T, class... Ts>
Shucai Xiao's avatar
Shucai Xiao committed
130
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
Paul's avatar
Paul committed
131
{
132
    MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
Paul's avatar
Paul committed
133
134
}

Paul's avatar
Paul committed
135
// Type-mapping trait, primary template: unless a specialization says
// otherwise, a type maps to itself.
template <class Element>
struct compute_rocblas_type
{
    using type = Element;
};

Paul's avatar
Paul committed
141
template <class T>
Paul's avatar
Paul committed
142
143
144
145
146
struct compute_rocblas_type<const T>
{
    using type = const typename compute_rocblas_type<T>::type;
};

Paul's avatar
Paul committed
147
template <>
Paul's avatar
Paul committed
148
149
150
151
152
struct compute_rocblas_type<half>
{
    using type = rocblas_half;
};

Paul's avatar
Paul committed
153
template <class T>
Paul's avatar
Paul committed
154
155
156
157
158
159
160
161
using rb_type = typename compute_rocblas_type<T>::type;

// Value overload: returns a copy of x with its storage reinterpreted as
// rb_type<T>.
template <class T>
rb_type<T> to_rocblas_type(T x)
{
    using target_type         = rb_type<T>;
    const auto& reinterpreted = reinterpret_cast<const target_type&>(x);
    return reinterpreted;
}

Paul's avatar
Paul committed
162
template <class T>
Paul's avatar
Paul committed
163
rb_type<T>* to_rocblas_type(T* x)
Paul's avatar
Paul committed
164
{
Paul's avatar
Paul committed
165
    return reinterpret_cast<rb_type<T>*>(x);
Paul's avatar
Paul committed
166
167
}

Paul's avatar
Paul committed
168
// Non-template overload for a half passed by value: returns a copy of x with
// its storage reinterpreted as rocblas_half.
rocblas_half to_rocblas_type(half x)
{
    const auto& reinterpreted = reinterpret_cast<const rocblas_half&>(x);
    return reinterpreted;
}
Paul's avatar
Paul committed
169

wsttiger's avatar
wsttiger committed
170
171
// Validates the operand shapes of the GPU GEMM and returns the result shape.
//
// inputs: operand shapes followed by one trailing shape, which is stripped
//         before validation (compute() returns its matching trailing argument
//         as the result).
// Returns whatever op.compute_shape yields for the remaining shapes.
// Throws (via MIGRAPHX_THROW) when fewer than two shapes are supplied, or when
// a batch stride of inputs[0] / inputs[1] is smaller than one of that tensor's
// two trailing strides. check_shapes and op.compute_shape may reject the
// shapes as well.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // inputs[0] and inputs[1] are indexed below and the last entry is dropped,
    // so fewer than two shapes would be out-of-range access.
    if(inputs.size() < 2)
    {
        MIGRAPHX_THROW("DOT: expected at least two input shapes");
    }

    std::vector<shape> input_shapes(inputs.begin(), inputs.end() - 1);
    check_shapes{input_shapes}.not_broadcasted();

    // Tensors of rank <= 2 have no batch strides and always pass. For higher
    // ranks every leading (batch) stride must be at least as large as both of
    // the last two strides. The batch/matrix split is computed from the rank
    // of the tensor being checked, so operands of different rank are each
    // validated against their own layout.
    auto require_batch_not_transposed = [](const auto& strides, const char* prefix) {
        if(strides.size() <= 2)
        {
            return;
        }
        const auto matrix_begin = strides.begin() + (strides.size() - 2);
        const bool batch_is_outermost =
            std::all_of(strides.begin(), matrix_begin, [&](auto batch_size) {
                return std::all_of(matrix_begin, strides.end(), [&](auto data_size) {
                    return batch_size >= data_size;
                });
            });
        if(!batch_is_outermost)
        {
            MIGRAPHX_THROW(prefix + to_string_range(strides) + "} is transposed!");
        }
    };

    require_batch_not_transposed(inputs[0].strides(), "DOT: batch size of a {");
    require_batch_not_transposed(inputs[1].strides(), "DOT: batch size of b {");

    return op.compute_shape(input_shapes);
}
Shucai Xiao's avatar
Shucai Xiao committed
203

wsttiger's avatar
wsttiger committed
204
205
206
// Runs the GEMM on the GPU stream owned by ctx and returns the output argument.
//
// ctx:          supplies the HIP stream and the rocBLAS handle.
// output_shape: shape (and element type) of the result; selects the rocBLAS
//               routine through visit_type.
// args:         {a, b, output} or, when four arguments are passed,
//               {a, b, c, output}. In the four-argument form c is first copied
//               into output and op.beta is used as beta; otherwise beta is 0.
// Returns the output argument (args[3] in the four-argument form, else args[2]).
// Throws (via MIGRAPHX_THROW) when the element type has no rocBLAS routine,
// when the device-to-device copy cannot be enqueued, or when rocBLAS reports a
// status other than rocblas_status_success.
argument miopen_gemm::compute(context& ctx,
                              const shape& output_shape,
                              const std::vector<argument>& args) const
{
    const bool is_3inputs = (args.size() == 4);
    // The buffer rocBLAS writes into; it is also what gets returned.
    const auto& result = args[is_3inputs ? 3 : 2];
    const float beta   = is_3inputs ? op.beta : 0.0f;

    if(is_3inputs)
    {
        // Seed the output with c so that the GEMM's beta term accumulates onto it.
        output_shape.visit_type([&](auto as) {
            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
            const auto copy_status = hipMemcpyAsync(to_pointer(args[3]),
                                                    to_pointer(args[2]),
                                                    output_shape.bytes(),
                                                    hipMemcpyDeviceToDevice,
                                                    ctx.get_stream().get());
            if(copy_status != hipSuccess)
            {
                MIGRAPHX_THROW("GEMM: failed to copy the c input into the output buffer");
            }
        });
    }

    output_shape.visit_type([&](auto as) {
        const auto& out_lens = output_shape.lens();
        auto n_dim           = out_lens.size();
        auto dim_1           = n_dim - 1;
        auto dim_0           = n_dim - 2;
        auto alpha_r         = to_rocblas_type(as(op.alpha));
        auto beta_r          = to_rocblas_type(as(beta));
        bool transa          = args[0].get_shape().transposed();
        bool transb          = args[1].get_shape().transposed();
        // A transposed operand takes its leading dimension from the last
        // stride instead of the second-to-last one.
        rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
        rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
        rocblas_int ldc = args[2].get_shape().strides()[dim_0];
        rocblas_int m   = out_lens[dim_0];
        rocblas_int n   = out_lens[dim_1];
        rocblas_int k   = args[0].get_shape().lens()[dim_1];
        // Product of every output dimension except the last two.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };

        // NOTE(review): b is passed before a and n before m -- presumably
        // because rocBLAS is column-major, so the row-major product a*b is
        // obtained as (b^T * a^T); confirm against the rocBLAS documentation.
        rocblas_status status = rocblas_status_success;
        if(num_matrices == 1)
        {
            status =
                generic_rocblas_gemm(as,
                                     ctx.get_stream().get_rocblas(),
                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
                                     n,
                                     m,
                                     k,
                                     &alpha_r,
                                     to_pointer(args[1]),
                                     ldb,
                                     to_pointer(args[0]),
                                     lda,
                                     &beta_r,
                                     to_pointer(result),
                                     ldc);
        }
        else
        {
            // The per-matrix offsets passed here (k*n, m*k, m*n) are the
            // element counts of one b, a and output matrix respectively.
            status = generic_rocblas_batched_gemm(
                as,
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args[1]),
                ldb,
                k * n,
                to_pointer(args[0]),
                lda,
                m * k,
                &beta_r,
                to_pointer(result),
                ldc,
                m * n,
                num_matrices);
        }

        // Surface rocBLAS failures instead of returning an output buffer that
        // was never written.
        if(status != rocblas_status_success)
        {
            MIGRAPHX_THROW("GEMM: rocblas gemm call failed");
        }
    });

    return result;
}

} // namespace gpu
Paul's avatar
Paul committed
290
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
291
} // namespace migraphx