#pragma once
#include "infinicore/nn/linear.hpp"
#include "infinicore/quantization.hpp"

#include "../engine/distributed/communication_group.hpp"

namespace infinilm::layers {
class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t q_dim, size_t k_dim, size_t v_dim,
                               size_t num_q_head, size_t num_k_head, size_t num_v_head,
                               bool q_bias, bool k_bias, bool v_bias,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // Convenience overload for the common case where all heads share the same dimension.
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t head_dim,
                               size_t num_q_head, size_t num_kv_head,
                               bool bias = false,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
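
    // Illustrative mapping (values are examples, not defaults): with
    // head_dim = 128, num_q_head = 32, num_kv_head = 8, this overload
    // corresponds to the general constructor called with
    //   q_dim = k_dim = v_dim = 128, num_k_head = num_v_head = 8,
    // and the single bias flag covering Q, K, and V alike.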

    explicit QKVParallelLinear(size_t hidden_size,
                               size_t q_dim, size_t k_dim, size_t v_dim,
                               size_t num_q_head, size_t num_k_head, size_t num_v_head,
                               bool q_bias, bool k_bias, bool v_bias,
                               std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // Convenience overload for the common case where all heads share the same dimension.
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t head_dim,
                               size_t num_q_head, size_t num_kv_head,
                               std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                               bool bias = false,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // Runs the fused QKV projection on `input` and splits the result into
    // (Q, K, V) tensors.
    std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
    forward_split(infinicore::Tensor &input);

    infinicore::nn::Parameter get_q_weight() const;
    infinicore::nn::Parameter get_k_weight() const;
    infinicore::nn::Parameter get_v_weight() const;

    infinicore::nn::Parameter get_q_weight_scale() const;
    infinicore::nn::Parameter get_k_weight_scale() const;
    infinicore::nn::Parameter get_v_weight_scale() const;

    infinicore::nn::Parameter get_q_weight_zeros() const;
    infinicore::nn::Parameter get_k_weight_zeros() const;
    infinicore::nn::Parameter get_v_weight_zeros() const;

    // AWQ accessors. The scaling_factor argument is the packing factor: the
    // number of low-bit elements packed into a single high-bit container
    // element. For example, int4 packed into int32 gives a packing factor of
    // 8 (32 bits / 4 bits = 8 int4 values per int32).
    infinicore::nn::Parameter get_q_weight_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_k_weight_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_v_weight_awq(int scaling_factor) const;

    infinicore::nn::Parameter get_q_weight_scale_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_k_weight_scale_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_v_weight_scale_awq(int scaling_factor) const;

    infinicore::nn::Parameter get_q_weight_zeros_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_k_weight_zeros_awq(int scaling_factor) const;
    infinicore::nn::Parameter get_v_weight_zeros_awq(int scaling_factor) const;
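
    // Worked example (illustrative, assuming int4 weights packed into int32
    // containers): the packing factor is 32 / 4 = 8, so 8 logical int4 values
    // occupy one int32 along the packed axis. Scales are stored unpacked,
    // which is why the init macros below pass a factor of 1 to
    // get_*_weight_scale_awq().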

    infinicore::nn::Parameter get_q_bias() const;
    infinicore::nn::Parameter get_k_bias() const;
    infinicore::nn::Parameter get_v_bias() const;

    bool has_q_bias() const;
    bool has_k_bias() const;
    bool has_v_bias() const;

private:
    size_t q_dim_;
    size_t k_dim_;
    size_t v_dim_;
    size_t num_q_head_;
    size_t num_k_head_;
    size_t num_v_head_;
    bool q_bias_;
    bool k_bias_;
    bool v_bias_;
    size_t q_out_size_; // num_q_head * q_dim / tp_size
    size_t k_out_size_; // num_k_head * k_dim / tp_size
    size_t v_out_size_; // num_v_head * v_dim / tp_size
};
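
// Usage sketch (illustrative values; `hidden_states` stands for an activation
// tensor produced elsewhere):
//
//   auto qkv = std::make_shared<QKVParallelLinear>(
//       /*hidden_size=*/4096, /*head_dim=*/128,
//       /*num_q_head=*/32, /*num_kv_head=*/8, /*bias=*/false);
//   auto [q, k, v] = qkv->forward_split(hidden_states);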

class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
    /**
     * @deprecated This constructor is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions permitted on this interface
     *   - Only critical bug fixes (security/stability) allowed until removal
     *   - All new code MUST migrate to the quantization-aware overloads below
     *
     * Replacement: use the constructor overloads below that accept a
     *              std::shared_ptr<infinicore::quantization::BaseQuantization>.
     * Reason: the legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false,
                         const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                         const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                         bool bias = false,
                         const infinicore::DataType &dtype = infinicore::DataType::F32,
                         const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                         std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                         const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
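
    // Migration sketch (illustrative): moving from the deprecated signature to
    // the quantization-aware overload above. `quant` stands for a
    // std::shared_ptr<infinicore::quantization::BaseQuantization> obtained
    // elsewhere.
    //
    //   // deprecated:
    //   GateUpParallelLinear(hidden_size, intermediate_size, /*bias=*/false);
    //   // preferred:
    //   GateUpParallelLinear(hidden_size, intermediate_size, quant, /*bias=*/false);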

    // Runs the fused gate/up projection on `input` and splits the result into
    // (gate, up) tensors.
    std::tuple<infinicore::Tensor, infinicore::Tensor> forward_split(infinicore::Tensor &input);

    infinicore::nn::Parameter get_gate_weight() const;

    infinicore::nn::Parameter get_gate_weight_scale() const;

    infinicore::nn::Parameter get_gate_weight_zeros() const;

    infinicore::nn::Parameter get_gate_bias() const;

    infinicore::nn::Parameter get_up_weight() const;

    infinicore::nn::Parameter get_up_weight_scale() const;

    infinicore::nn::Parameter get_up_weight_zeros() const;

    infinicore::nn::Parameter get_up_bias() const;

    infinicore::nn::Parameter get_gate_weight_awq() const;

    infinicore::nn::Parameter get_up_weight_awq() const;

    infinicore::nn::Parameter get_up_weight_scale_awq() const;

    infinicore::nn::Parameter get_up_weight_zeros_awq() const;

    infinicore::nn::Parameter get_gate_weight_scale_awq() const;

    infinicore::nn::Parameter get_gate_weight_zeros_awq() const;

    bool has_gate_bias() const;

    bool has_up_bias() const;

private:
    bool gate_bias_;
    bool up_bias_;
};
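
// Usage sketch (illustrative): a fused gate/up projection for a gated MLP.
//
//   auto gate_up = std::make_shared<GateUpParallelLinear>(
//       /*hidden_size=*/4096, /*intermediate_size=*/11008, /*bias=*/false);
//   auto [gate, up] = gate_up->forward_split(hidden_states);
//   // Downstream (not part of this header), a SwiGLU-style block would
//   // compute act(gate) * up before the down projection.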

// Construction helpers: build a fused layer and register its per-projection
// weights (and, when present, biases) under the caller-supplied names.
#define INFINILM_QKV_LINEAR_INIT(name, q_name, k_name, v_name, ...)                     \
    name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__);                 \
    this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \
    this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \
    this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \
    if (name##_->has_q_bias())                                                          \
        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
    if (name##_->has_k_bias())                                                          \
        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
    if (name##_->has_v_bias())                                                          \
        this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());

#define INFINILM_GATE_UP_LINEAR_INIT(name, gate_name, up_name, ...)                           \
    name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__);                    \
    this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \
    this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight());     \
    if (name##_->has_gate_bias())                                                             \
        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
    if (name##_->has_up_bias())                                                               \
        this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
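
// Invocation sketch (illustrative names): inside a module class that declares
// a member `qkv_proj_`, this constructs the fused layer and registers
// "q_proj.weight", "k_proj.weight", and "v_proj.weight" (plus ".bias" entries
// for projections whose bias flag is set):
//
//   INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj",
//                            hidden_size, head_dim, num_q_head, num_kv_head);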

// ========================= QKV Quantization ==================================
#define INFINILM_QKV_LINEAR_W8A8_INIT(name, q_name, k_name, v_name, ...)                            \
    name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__);                             \
    this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight());             \
    this->register_parameter(std::string(q_name) + ".weight_scale", name##_->get_q_weight_scale()); \
    this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight());             \
    this->register_parameter(std::string(k_name) + ".weight_scale", name##_->get_k_weight_scale()); \
    this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight());             \
    this->register_parameter(std::string(v_name) + ".weight_scale", name##_->get_v_weight_scale()); \
    if (name##_->has_q_bias())                                                                      \
        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias());             \
    if (name##_->has_k_bias())                                                                      \
        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias());             \
    if (name##_->has_v_bias())                                                                      \
        this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());

// AWQ variant: casts the layer's quantization object to AWQ to obtain the
// packing factor, then registers packed qweight/qzeros and unpacked scales.
// Note: the macro declares local variables, so it must expand in a statement
// context (e.g. a constructor body).
#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...)                                 \
    name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__);                                      \
    auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(name##_->get_quantization());     \
    int packing_num = awq_ptr->get_packing_num();                                                            \
    this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num));      \
    this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros_awq(packing_num)); \
    this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale_awq(1));           \
    this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight_awq(packing_num));      \
    this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros_awq(packing_num)); \
    this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale_awq(1));           \
    this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight_awq(packing_num));      \
    this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros_awq(packing_num)); \
    this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale_awq(1));           \
    if (name##_->has_q_bias())                                                                               \
        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias());                      \
    if (name##_->has_k_bias())                                                                               \
        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias());                      \
    if (name##_->has_v_bias())                                                                               \
        this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
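
// Invocation sketch (illustrative): with `awq` a quantization object created
// elsewhere, the AWQ macro registers the packed tensors under "<name>.qweight",
// "<name>.qzeros", and "<name>.scales":
//
//   INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj",
//       hidden_size, head_dim, num_q_head, num_kv_head, awq);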

// ========================= Gate-Up Quantization ==============================
#define INFINILM_GATE_UP_LINEAR_W8A8_INIT(name, gate_name, up_name, ...)                                  \
    name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__);                                \
    this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight());             \
    this->register_parameter(std::string(gate_name) + ".weight_scale", name##_->get_gate_weight_scale()); \
    this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight());                 \
    this->register_parameter(std::string(up_name) + ".weight_scale", name##_->get_up_weight_scale());     \
    if (name##_->has_gate_bias())                                                                         \
        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias());             \
    if (name##_->has_up_bias())                                                                           \
        this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());

#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...)                            \
    name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__);                              \
    this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight_awq());      \
    this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros_awq()); \
    this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale_awq()); \
    this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight_awq());          \
    this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros_awq());     \
    this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale_awq());     \
    if (name##_->has_gate_bias())                                                                       \
        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias());           \
    if (name##_->has_up_bias())                                                                         \
        this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
} // namespace infinilm::layers