quant_gemm.cpp 5.27 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/context.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
3
#include <migraphx/generate.hpp>
4
5
#include <fstream>
#include <iomanip>
Shucai Xiao's avatar
Shucai Xiao committed
6
7
8
9
10

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

11
// Validates the operand shapes and returns the shape computed by the wrapped op.
//
// The trailing entry of `inputs` is removed before validation and shape
// inference, so only the leading entries are checked for broadcasting and
// forwarded to op.compute_shape(). The first two inputs are additionally
// required to have a non-transposed batch layout.
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // Work on a copy so the caller's vector is left untouched.
    auto operand_shapes = inputs;
    operand_shapes.pop_back();

    check_shapes{operand_shapes}.not_broadcasted();

    // Both matrix operands must keep their batch dimensions outermost.
    const auto& lhs = inputs[0];
    const auto& rhs = inputs[1];
    batch_not_transposed(lhs.strides());
    batch_not_transposed(rhs.strides());

    return op.compute_shape(operand_shapes);
}

22
23
24
25
26
27
28
29
30
31
32
void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
    if(strides.size() <= 2)
        return;
    auto dim_0       = strides.size() - 2;
    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
    std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
    if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
           return (i < j or i < matrix_size or j < matrix_size);
       }) != batch.end())
    {
Shucai Xiao's avatar
Shucai Xiao committed
33
        MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!");
34
35
36
    }
}

37
// Runs the quantized matrix multiply through rocBLAS and returns the result
// argument.
//
// Argument layout, as used below:
//   args[0], args[1] : the two operands, handed to rocBLAS as int8 data
//   args[2]          : handed to rocBLAS as the int32 C matrix
//   args[3]          : only read when args.size() == 4; receives the result
// With four arguments beta is taken from op.beta and the result is written to
// args[3]; otherwise beta is 0 and args[2] doubles as the result matrix.
// A single output matrix uses rocblas_gemm_ex; more than one uses
// rocblas_gemm_strided_batched_ex.
argument rocblas_quant_gemm::compute(context& ctx,
                                     const shape& output_shape,
                                     const std::vector<argument>& args) const
{
    const auto& a_shape = args[0].get_shape();
    const auto& b_shape = args[1].get_shape();

    const bool transa = a_shape.transposed();
    const bool transb = b_shape.transposed();
    const auto n_dim  = output_shape.lens().size();
    const auto dim_1  = n_dim - 1;
    const auto dim_0  = n_dim - 2;

    // Leading dimensions: a transposed operand contributes the stride of its
    // last dimension, a non-transposed one the stride of its second-to-last.
    const rocblas_int lda = a_shape.strides()[transa ? dim_1 : dim_0];
    const rocblas_int ldb = b_shape.strides()[transb ? dim_1 : dim_0];
    const rocblas_int ldc = args[2].get_shape().strides()[dim_0];

    const bool is_3inputs = (args.size() == 4);
    const int32_t beta    = is_3inputs ? static_cast<int32_t>(op.beta) : 0;

    // The argument rocBLAS writes the product into, and which is returned.
    const argument& result = is_3inputs ? args[3] : args[2];

    const auto op_a = transa ? rocblas_operation_transpose : rocblas_operation_none;
    const auto op_b = transb ? rocblas_operation_transpose : rocblas_operation_none;

    output_shape.visit_type([&](auto as) {
        auto alpha_r         = as(op.alpha);
        auto beta_r          = as(beta);
        const auto& out_lens = output_shape.lens();
        const rocblas_int m  = out_lens[dim_0];
        const rocblas_int n  = out_lens[dim_1];
        const rocblas_int k  = a_shape.lens()[dim_1];
        auto to_pointer      = [&](auto&& arg) { return as.from(arg.data()); };
        // Only verified in builds where assert is active.
        assert(k % 4 == 0);

        auto lhs_ptr = to_pointer(args.at(0));
        auto rhs_ptr = to_pointer(args.at(1));
        auto c_ptr   = to_pointer(args[2]);
        auto out_ptr = to_pointer(result);

        // Product of every output dimension except the trailing two.
        const auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

        // the rocblas_gemm API handles inputs and output matrices as
        // column-major format. When doing a C = A * B, we actually do
        // C^T = (B^T) * (A^T). That is the reason we input args[1] as
        // A and args[0] as B in calling the rocblas_gemm.
        if(num_matrices == 1)
        {
            rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                            op_b,
                            op_a,
                            n,
                            m,
                            k,
                            &alpha_r,
                            rhs_ptr,
                            rocblas_datatype_i8_r,
                            ldb,
                            lhs_ptr,
                            rocblas_datatype_i8_r,
                            lda,
                            &beta_r,
                            c_ptr,
                            rocblas_datatype_i32_r,
                            ldc,
                            out_ptr,
                            rocblas_datatype_i32_r,
                            ldc,
                            rocblas_datatype_i32_r,
                            rocblas_gemm_algo_standard,
                            0,
                            0,
                            nullptr,
                            nullptr);
        }
        else
        {
            // Consecutive matrices of a batch are taken to be k * n, m * k and
            // m * n elements apart for args[1], args[0] and the C/result
            // matrices respectively.
            rocblas_gemm_strided_batched_ex(ctx.get_stream().get_rocblas(),
                                            op_b,
                                            op_a,
                                            n,
                                            m,
                                            k,
                                            &alpha_r,
                                            rhs_ptr,
                                            rocblas_datatype_i8_r,
                                            ldb,
                                            k * n,
                                            lhs_ptr,
                                            rocblas_datatype_i8_r,
                                            lda,
                                            m * k,
                                            &beta_r,
                                            c_ptr,
                                            rocblas_datatype_i32_r,
                                            ldc,
                                            m * n,
                                            out_ptr,
                                            rocblas_datatype_i32_r,
                                            ldc,
                                            m * n,
                                            num_matrices,
                                            rocblas_datatype_i32_r,
                                            rocblas_gemm_algo_standard,
                                            0,
                                            0,
                                            nullptr,
                                            nullptr);
        }
    });

    return result;
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx