/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <algorithm>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

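// Map a migraphx shape type to the corresponding rocBLAS datatype.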
rocblas_datatype get_type(shape::type_t type)
{
    switch(type)
    {
    case shape::double_type: return rocblas_datatype_f64_r;
    case shape::float_type: return rocblas_datatype_f32_r;
    case shape::half_type: return rocblas_datatype_f16_r;
    case shape::int8_type: return rocblas_datatype_i8_r;
    case shape::uint8_type: return rocblas_datatype_u8_r;
    case shape::int32_type: return rocblas_datatype_i32_r;
    case shape::uint32_type: return rocblas_datatype_u32_r;
    case shape::tuple_type:
    case shape::bool_type:
    case shape::uint16_type:
    case shape::int16_type:
    case shape::int64_type:
    case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
    }

    MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}

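// Check that a shape is usable by rocBLAS GEMM: one of the two innermost
// strides must be 1, and any batch dimensions must be collapsible into a
// single dimension.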
void blas_shape(const shape& s)
{
    if(s.lens().size() < 2)
        return;
    if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
        MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
    if(s.lens().size() < 3)
        return;
    shape batch_shape{s.type(),
                      {s.lens().begin(), s.lens().end() - 2},
                      {s.strides().begin(), s.strides().end() - 2}};
    auto batch_shapes = reduce_dims({batch_shape});
    if(batch_shapes.front().lens().size() != 1)
        MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}

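// Permute the batch dimension at position (rank - 3) with the one
// trans_batch places after it; returns the shape unchanged when
// trans_batch is 0 or there is no batch dimension.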
shape transpose_batch(const shape& s, unsigned trans_batch)
{
    if(trans_batch == 0)
        return s;
    if(s.lens().size() < 3)
        return s;
    auto batch = s.lens().size() - 3;
    std::vector<int64_t> perm(s.lens().size());
    std::iota(perm.begin(), perm.end(), 0);
    std::swap(perm[batch], perm[batch + trans_batch]);
    return shape::from_permutation(s.type(), s.lens(), perm);
}

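// Call a rocBLAS function, appending two trailing null arguments when the
// function expects more parameters than were supplied; this smooths over
// signature differences between rocBLAS versions.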
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
{
    if constexpr(sizeof...(Ts) == sizeof...(Us))
        return f(xs...);
    else
        return f(xs..., nullptr, nullptr);
}

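// An argument needs a rocBLAS transpose operation when its shape is
// transposed and its innermost stride is not 1.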
static bool is_transposed(const shape& s)
{
    if(not s.transposed())
        return false;
    return s.strides().back() != 1;
}

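// Stride between consecutive matrices of a batched argument (the
// third-from-innermost stride).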
static rocblas_int get_batch_stride(const argument& a)
{
    return a.get_shape().strides()[a.get_shape().strides().size() - 3];
}

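// Compute D = alpha * A * B + beta * C through rocBLAS. args holds
// {A, B, C} with C as the output, or {A, B, C, D} with D as the output.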
template <class T>
void gemm_impl(context& ctx,
               const shape& output_shape,
               const std::vector<argument>& args,
               T alpha,
               T beta,
               bool int8_x4_format,
               bool compute_fp32)
{
    const bool is_3inputs = (args.size() == 4);
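    // Without a separate C input there is nothing to accumulate into the
    // output, so force beta to zero.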
    if(not is_3inputs)
    {
        beta = 0;
    }

    bool transa     = is_transposed(args[0].get_shape());
    bool transb     = is_transposed(args[1].get_shape());
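    // Leading dimensions come from the stride of each matrix's outer
    // dimension, or of its inner dimension when the input is transposed.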
    auto n_dim      = output_shape.lens().size();
    auto dim_1      = n_dim - 1;
    auto dim_0      = n_dim - 2;
    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    rocblas_int ldc = args[2].get_shape().strides()[dim_0];
    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;

    rocblas_datatype arg_type = get_type(args[0].get_shape().type());
    auto output_type          = arg_type;
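    // rocBLAS accumulates int8 GEMM results into an int32 output.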
    if(output_type == rocblas_datatype_i8_r)
    {
        output_type = rocblas_datatype_i32_r;
    }
    auto compute_type = output_type;
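    // Optionally accumulate fp16 GEMMs in fp32 for better accuracy.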
    if(compute_fp32)
    {
        if(arg_type == rocblas_datatype_f16_r)
            compute_type = rocblas_datatype_f32_r;
    }

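    // The int8x4 packing flag only exists from rocBLAS 2.38 onward; older
    // versions take a plain integer flag.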
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
    rocblas_gemm_flags flag =
        int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
    (void)int8_x4_format;
    int flag = 0;
#endif

    auto a_lens = args[0].get_shape().lens();
    auto b_lens = args[1].get_shape().lens();
    output_shape.visit_type([&](auto as) {
        auto alpha_r = as(alpha);
        auto beta_r  = as(beta);

        // Use void pointers for alpha and beta so either the element-typed
        // values or the fp32 values can be passed, depending on compute mode.
        void* alpha_v = &alpha_r;
        void* beta_v  = &beta_r;

        if(compute_fp32)
        {
            alpha_v = &alpha;
            beta_v  = &beta;
        }

        auto out_lens   = output_shape.lens();
        rocblas_int m   = out_lens[dim_0];
        rocblas_int n   = out_lens[dim_1];
        rocblas_int k   = args[0].get_shape().lens()[dim_1];
        auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
        if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
        {
            MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
        }

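        // Number of matrices in the batch: the product of all dimensions
        // except the last two.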
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
        {
            // If the batch dimension of B is broadcasted, then we can
            // multiply m by the batch_size and use rocblas_gemm_ex
            // instead of rocblas_gemm_strided_batched_ex.
            m *= num_matrices;

            // The rocblas_gemm API treats input and output matrices as
            // column-major. To compute C = A * B we actually compute
            // C^T = B^T * A^T, which is why args[1] is passed as A and
            // args[0] as B when calling rocblas_gemm.
            rocblas_invoke(&rocblas_gemm_ex,
                           ctx.get_stream().get_rocblas(),
                           transb ? rocblas_operation_transpose : rocblas_operation_none,
                           transa ? rocblas_operation_transpose : rocblas_operation_none,
                           n,
                           m,
                           k,
                           alpha_v,
                           to_pointer(args.at(1)),
                           arg_type,
                           ldb,
                           to_pointer(args.at(0)),
                           arg_type,
                           lda,
                           beta_v,
                           to_pointer(args[2]),
                           output_type,
                           ldc,
                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                           output_type,
                           ldd,
                           compute_type,
                           rocblas_gemm_algo_standard,
                           0,
                           flag);
        }
        else
        {
            auto a_stride = get_batch_stride(args[0]);
            auto b_stride = get_batch_stride(args[1]);
            auto c_stride = get_batch_stride(args[2]);
            auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
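            // General batched case: a single strided-batched GEMM over all
            // matrices, again with A and B swapped for column-major order.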
            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                           ctx.get_stream().get_rocblas(),
                           transb ? rocblas_operation_transpose : rocblas_operation_none,
                           transa ? rocblas_operation_transpose : rocblas_operation_none,
                           n,
                           m,
                           k,
                           alpha_v,
                           to_pointer(args.at(1)),
                           arg_type,
                           ldb,
                           b_stride,
                           to_pointer(args.at(0)),
                           arg_type,
                           lda,
                           a_stride,
                           beta_v,
                           to_pointer(args[2]),
                           output_type,
                           ldc,
                           c_stride,
                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                           output_type,
                           ldd,
                           d_stride,
                           num_matrices,
                           compute_type,
                           rocblas_gemm_algo_standard,
                           0,
                           flag);
        }
    });
}

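// Public entry points: dispatch to the templated implementation with float
// or int32 scaling factors.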
void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          float alpha,
          float beta,
          bool int8_x4_format,
          bool compute_fp32)
{
    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}

void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          int32_t alpha,
          int32_t beta,
          bool int8_x4_format,
          bool compute_fp32)
{
    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx