quant_gemm.cpp 5.47 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
#include <algorithm>
#include <fstream>
#include <iomanip>
Shucai Xiao's avatar
Shucai Xiao committed
6
7
8
9
10

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

11
// Computes the result shape of the quantized GEMM.
//
// The last entry of `inputs` is dropped before the remaining shapes are
// validated and handed to op.compute_shape (presumably the dropped entry is
// the pre-allocated output buffer -- confirm against the caller).
// Validation rejects broadcasted shapes and, via batch_not_transposed,
// batch strides that are not laid out as expected for the first two inputs.
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // All shapes except the trailing one take part in the shape computation.
    std::vector<shape> operand_shapes(inputs.begin(), inputs.end() - 1);
    check_shapes{operand_shapes}.not_broadcasted();

    // The two matrix operands must keep their batch dimensions untransposed.
    const auto& lhs = inputs[0];
    const auto& rhs = inputs[1];
    batch_not_transposed(lhs.strides());
    batch_not_transposed(rhs.strides());

    return op.compute_shape(operand_shapes);
}

22
23
24
25
26
27
28
29
30
31
32
void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
    if(strides.size() <= 2)
        return;
    auto dim_0       = strides.size() - 2;
    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
    std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
    if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
           return (i < j or i < matrix_size or j < matrix_size);
       }) != batch.end())
    {
Shucai Xiao's avatar
Shucai Xiao committed
33
        MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!");
34
35
36
    }
}

37
// Runs the quantized (int8 inputs, int32 accumulation/output) matrix multiply
// through rocBLAS.
//
// args[0] and args[1] are the two matrix operands. With three arguments,
// args[2] is used as both the C input and the result buffer and beta is 0.
// With four arguments, args[2] is the C input scaled by op.beta and args[3]
// receives the result. The argument that received the result is returned.
//
// A single matrix product is dispatched to rocblas_gemm_ex; anything with
// more than one matrix in the leading (batch) dimensions goes through
// rocblas_gemm_strided_batched_ex.
//
// NOTE(review): the status returned by the rocBLAS calls is discarded, so a
// failing call is not reported to the caller.
argument rocblas_quant_gemm::compute(context& ctx,
                                     const shape& output_shape,
                                     const std::vector<argument>& args) const
{
    bool transa = args[0].get_shape().transposed();
    bool transb = args[1].get_shape().transposed();
    auto n_dim  = output_shape.lens().size();
    auto dim_1  = n_dim - 1;
    auto dim_0  = n_dim - 2;
    // Leading dimension: stride of the last axis when the operand is
    // transposed, otherwise stride of the second-to-last axis.
    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
    rocblas_int ldc = args[2].get_shape().strides()[dim_0];

    // Four arguments means A, B, C plus a separate result buffer.
    bool is_3inputs = (args.size() == 4);
    int32_t beta    = 0;
    if(is_3inputs)
    {
        beta = op.beta;
    }

    output_shape.visit_type([&](auto as) {
        auto alpha_r         = as(op.alpha);
        auto beta_r          = as(beta);
        const auto& out_lens = output_shape.lens();
        rocblas_int m        = out_lens[dim_0];
        rocblas_int n        = out_lens[dim_1];
        rocblas_int k        = args[0].get_shape().lens()[dim_1];
        auto to_pointer      = [&](auto&& arg) { return as.from(arg.data()); };
        // NOTE(review): presumably a rocBLAS int8 GEMM requirement -- confirm.
        assert(k % 4 == 0);

        // Number of matrices = product of all output dims except the last two.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        if(num_matrices == 1)
        {
            // the rocblas_gemm API handles inputs and output matrices as
            // column-major format. When doing a C = A * B, we actually do
            // C^T = (B^T) * (A^T). That is the reason we input args[1] as
            // A and args[0] as B in calling the rocblas_gemm.
            rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,
                            transa ? rocblas_operation_transpose : rocblas_operation_none,
                            n,
                            m,
                            k,
                            &alpha_r,
                            to_pointer(args.at(1)),
                            rocblas_datatype_i8_r,
                            ldb,
                            to_pointer(args.at(0)),
                            rocblas_datatype_i8_r,
                            lda,
                            &beta_r,
                            to_pointer(args[2]),
                            rocblas_datatype_i32_r,
                            ldc,
                            is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                            rocblas_datatype_i32_r,
                            ldc,
                            rocblas_datatype_i32_r,
                            rocblas_gemm_algo_standard,
                            0,
                            0,
                            nullptr,
                            nullptr);
        }
        else
        {
            // Same operand swap as the single-matrix case above.
            // NOTE(review): the batch strides k * n, m * k and m * n assume
            // every matrix in the batch is densely packed -- confirm that the
            // shape checks guarantee this.
            rocblas_gemm_strided_batched_ex(
                ctx.get_stream().get_rocblas(),
                transb ? rocblas_operation_transpose : rocblas_operation_none,
                transa ? rocblas_operation_transpose : rocblas_operation_none,
                n,
                m,
                k,
                &alpha_r,
                to_pointer(args.at(1)),
                rocblas_datatype_i8_r,
                ldb,
                k * n,
                to_pointer(args.at(0)),
                rocblas_datatype_i8_r,
                lda,
                m * k,
                &beta_r,
                to_pointer(args[2]),
                rocblas_datatype_i32_r,
                ldc,
                m * n,
                is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                rocblas_datatype_i32_r,
                ldc,
                m * n,
                num_matrices,
                rocblas_datatype_i32_r,
                rocblas_gemm_algo_standard,
                0,
                0,
                nullptr,
                nullptr);
        }
    });

    return is_3inputs ? args[3] : args[2];
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx