gemm_impl.cpp 6.01 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
#include <rocblas.h>
2
#include <migraphx/gpu/gemm_impl.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
3
4
5
6
7

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

8
// Translate a MIGraphX element type into the corresponding rocBLAS datatype
// enumerator. Element types that have no mapping below are rejected by
// throwing through MIGRAPHX_THROW.
rocblas_datatype get_type(shape::type_t type)
{
    // The switch deliberately has no `default:` label; anything that is not
    // returned from inside the switch ends up at the throw underneath it.
    switch(type)
    {
    // Integer element types that map onto a rocBLAS datatype.
    case shape::int8_type: return rocblas_datatype_i8_r;
    case shape::uint8_type: return rocblas_datatype_u8_r;
    case shape::int32_type: return rocblas_datatype_i32_r;
    case shape::uint32_type: return rocblas_datatype_u32_r;

    // Floating point element types that map onto a rocBLAS datatype.
    case shape::half_type: return rocblas_datatype_f16_r;
    case shape::float_type: return rocblas_datatype_f32_r;
    case shape::double_type: return rocblas_datatype_f64_r;

    // Element types without a mapping: leave the switch and throw below.
    case shape::bool_type:
    case shape::int16_type:
    case shape::uint16_type:
    case shape::int64_type:
    case shape::uint64_type: break;
    }

    MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}

29
30
31
32
33
34
35
36
37
// Call `f` with the arguments `xs...`.
//
// If the number of supplied arguments equals the number of parameters `f`
// declares, they are passed through unchanged. Otherwise two extra `nullptr`
// arguments are appended after `xs...` before calling `f`.
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
{
    if constexpr(sizeof...(Us) != sizeof...(Ts))
    {
        // Argument count differs from the parameter count: pad with two nulls.
        return f(xs..., nullptr, nullptr);
    }
    else
    {
        // Argument count already matches the signature of f.
        return f(xs...);
    }
}

38
39
40
// Run a matrix multiplication (single or strided-batched) through rocBLAS.
//
// Arguments:
//   ctx          - GPU context; its stream supplies the rocBLAS handle.
//   output_shape - shape of the result; its last two dimensions are taken as
//                  m x n, and all leading dimensions are multiplied together
//                  to obtain the number of matrices in the batch.
//   args         - args[0] and args[1] are the two operands, args[2] is passed
//                  as the C matrix. When args holds exactly four entries,
//                  args[3] is passed as the output (D) matrix; otherwise
//                  args[2] is used for both C and D and beta is forced to 0.
//   alpha, beta  - scalars, converted to the output element type before use.
//
// Throws through MIGRAPHX_THROW when the element type of args[0] has no
// rocBLAS mapping (see get_type), or when args[0] is int8 and k is not a
// multiple of 4.
template <class T>
void gemm_impl(
    context& ctx, const shape& output_shape, const std::vector<argument>& args, T alpha, T beta)
{
    const bool transa = args[0].get_shape().transposed();
    const bool transb = args[1].get_shape().transposed();
    const auto n_dim  = output_shape.lens().size();
    const auto dim_1  = n_dim - 1;
    const auto dim_0  = n_dim - 2;

    // Leading dimensions are read from the strides of the last two axes; for
    // a transposed operand the stride of the last axis is used instead.
    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    rocblas_int ldc = args[2].get_shape().strides()[dim_0];

    // Four entries in args means a separate C operand (args[2]) and output
    // buffer (args[3]). Without it, C is not accumulated, so beta is zeroed.
    const bool is_3inputs = (args.size() == 4);
    if(!is_3inputs)
    {
        beta = 0;
    }

    // int8 inputs produce int32 output; every other type keeps its own type
    // for both the output and the computation.
    rocblas_datatype arg_type = get_type(args[0].get_shape().type());
    auto output_type          = arg_type;
    if(output_type == rocblas_datatype_i8_r)
    {
        output_type = rocblas_datatype_i32_r;
    }
    auto compute_type = output_type;

    output_shape.visit_type([&](auto as) {
        // Scalars are converted to the element type of the output.
        auto alpha_r = as(alpha);
        auto beta_r  = as(beta);

        // Bind by reference: the dimensions are only read, never modified.
        const auto& out_lens = output_shape.lens();
        rocblas_int m        = out_lens[dim_0];
        rocblas_int n        = out_lens[dim_1];
        rocblas_int k        = args[0].get_shape().lens()[dim_1];
        auto to_pointer      = [&](auto&& arg) { return as.from(arg.data()); };

        if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0)
        {
            MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
        }

        // Product of every dimension in front of the trailing m x n pair.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        if(num_matrices == 1)
        {
            // the rocblas_gemm API handles inputs and output matrices as
            // column-major format. When doing a C = A * B, we actually do
            // C^T = (B^T) * (A^T). That is the reason we input args[1] as
            // A and args[0] as B in calling the rocblas_gemm.
            rocblas_invoke(&rocblas_gemm_ex,
                           ctx.get_stream().get_rocblas(),
                           transb ? rocblas_operation_transpose : rocblas_operation_none,
                           transa ? rocblas_operation_transpose : rocblas_operation_none,
                           n,
                           m,
                           k,
                           &alpha_r,
                           to_pointer(args.at(1)),
                           arg_type,
                           ldb,
                           to_pointer(args.at(0)),
                           arg_type,
                           lda,
                           &beta_r,
                           to_pointer(args[2]),
                           output_type,
                           ldc,
                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                           output_type,
                           ldc,
                           compute_type,
                           rocblas_gemm_algo_standard,
                           0,
                           0);
        }
        else
        {
            // Same operand swap as the single-matrix case above (args[1] is
            // passed first, args[0] second). Each operand is additionally
            // followed by its per-matrix stride: k * n, m * k and m * n.
            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                           ctx.get_stream().get_rocblas(),
                           transb ? rocblas_operation_transpose : rocblas_operation_none,
                           transa ? rocblas_operation_transpose : rocblas_operation_none,
                           n,
                           m,
                           k,
                           &alpha_r,
                           to_pointer(args.at(1)),
                           arg_type,
                           ldb,
                           k * n,
                           to_pointer(args.at(0)),
                           arg_type,
                           lda,
                           m * k,
                           &beta_r,
                           to_pointer(args[2]),
                           output_type,
                           ldc,
                           m * n,
                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                           output_type,
                           ldc,
                           m * n,
                           num_matrices,
                           compute_type,
                           rocblas_gemm_algo_standard,
                           0,
                           0);
        }
    });
}
Shucai Xiao's avatar
Shucai Xiao committed
148

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// GEMM entry point taking float scalars: forwards every argument unchanged
// to gemm_impl instantiated for float.
void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          float alpha,
          float beta)
{
    gemm_impl<float>(ctx, output_shape, args, alpha, beta);
}

// GEMM entry point taking int32_t scalars: forwards every argument unchanged
// to gemm_impl instantiated for int32_t.
void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          int32_t alpha,
          int32_t beta)
{
    gemm_impl<int32_t>(ctx, output_shape, args, alpha, beta);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx