// gemm_impl.cpp — rocBLAS-backed GEMM implementation for the MIGraphX GPU backend.
#include <rocblas-types.h>
#include <migraphx/gpu/gemm_impl.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// Map a migraphx shape element type to the matching rocBLAS real datatype.
// Throws MIGRAPHX_THROW for element types rocBLAS GEMM does not support
// (16-bit and 64-bit integers, or anything not listed in the switch).
rocblas_datatype get_type(shape::type_t type)
{
    switch(type)
    {
    case shape::double_type: return rocblas_datatype_f64_r;
    case shape::float_type: return rocblas_datatype_f32_r;
    case shape::half_type: return rocblas_datatype_f16_r;
    case shape::int8_type: return rocblas_datatype_i8_r;
    case shape::uint8_type: return rocblas_datatype_u8_r;
    case shape::int32_type: return rocblas_datatype_i32_r;
    case shape::uint32_type: return rocblas_datatype_u32_r;
    // Unsupported widths are listed explicitly (keeps -Wswitch coverage) but
    // fall through to the single throw below so the message lives in one place.
    case shape::uint16_type:
    case shape::int16_type:
    case shape::int64_type:
    case shape::uint64_type: break;
    }

    MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}

// Run output = alpha * A * B + beta * C through rocBLAS, covering both the
// single-matrix and strided-batched cases.
//
// args[0] = A, args[1] = B, args[2] = C. With three arguments the result is
// written in place over args[2] (beta is forced to 0 below); with four
// arguments args[3] receives the result and args[2] is read as C.
template <class T>
void gemm_impl(
    context& ctx, const shape& output_shape, const std::vector<argument>& args, T alpha, T beta)
{
    bool transa     = args[0].get_shape().transposed();
    bool transb     = args[1].get_shape().transposed();
    auto n_dim      = output_shape.lens().size();
    auto dim_1      = n_dim - 1;
    auto dim_0      = n_dim - 2;
    // Leading dimension is the stride of the slower-moving matrix axis; for a
    // transposed input that axis is the trailing one.
    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    rocblas_int ldc = args[2].get_shape().strides()[dim_0];

    bool is_3inputs = (args.size() == 4);
    if(!is_3inputs)
    {
        // With only three arguments the output aliases args[2], so its prior
        // contents must not contribute to the result.
        beta = 0;
    }
    rocblas_datatype arg_type = get_type(args[0].get_shape().type());
    auto output_type          = arg_type;
    if(output_type == rocblas_datatype_i8_r)
    {
        // int8 inputs accumulate into an int32 output.
        output_type = rocblas_datatype_i32_r;
    }
    auto compute_type = output_type;

    output_shape.visit_type([&](auto as) {
        auto alpha_r    = as(alpha);
        auto beta_r     = as(beta);
        auto out_lens   = output_shape.lens();
        rocblas_int m   = out_lens[dim_0];
        rocblas_int n   = out_lens[dim_1];
        rocblas_int k   = args[0].get_shape().lens()[dim_1];
        auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
        if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0)
        {
            MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be multiple of 4!");
        }

        // Every dimension beyond the trailing two is flattened into one
        // batch count.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        if(num_matrices == 1)
        {
            // the rocblas_gemm API handles inputs and output matrices as
            // column-major format. When doing a C = A * B, we actually do
            // C^T = (B^T) * (A^T). That is the reason we input args[1] as
            // A and args[0] as B in calling the rocblas_gemm.
            rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,
                            transa ? rocblas_operation_transpose : rocblas_operation_none,
                            n,
                            m,
                            k,
                            &alpha_r,
                            to_pointer(args.at(1)),
                            arg_type,
                            ldb,
                            to_pointer(args.at(0)),
                            arg_type,
                            lda,
                            &beta_r,
                            to_pointer(args[2]),
                            output_type,
                            ldc,
                            is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                            output_type,
                            ldc,
                            compute_type,
                            rocblas_gemm_algo_standard,
                            0,
                            0,
                            nullptr,
                            nullptr);
        }
        else
        {
            // Same column-major A/B swap as above, with per-matrix strides so
            // a single call covers the whole batch.
            rocblas_gemm_strided_batched_ex(
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args.at(1)),
                arg_type,
                ldb,
                k * n,
                to_pointer(args.at(0)),
                arg_type,
                lda,
                m * k,
                &beta_r,
                to_pointer(args[2]),
                output_type,
                ldc,
                m * n,
                is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                output_type,
                ldc,
                m * n,
                num_matrices,
                compute_type,
                rocblas_gemm_algo_standard,
                0,
                0,
                nullptr,
                nullptr);
        }
    });
}
// Single-precision entry point: forwards to the templated rocBLAS GEMM driver.
void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          float alpha,
          float beta)
{
    gemm_impl<float>(ctx, output_shape, args, alpha, beta);
}

// 32-bit integer entry point: forwards to the templated rocBLAS GEMM driver.
void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          int32_t alpha,
          int32_t beta)
{
    gemm_impl<int32_t>(ctx, output_shape, args, alpha, beta);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx