"tests/data/streamResultNeo.log" did not exist on "62a2913497a866754ae96d57ef445d8cec6e89b2"
fused_linear.cpp 14 KB
Newer Older
1
2
3
4
5
6
7
8
#include "fused_linear.hpp"

#include <spdlog/spdlog.h>

namespace infinilm::layers {
// ---------------------------------------------------------
// QKV Parallel Linear
// ---------------------------------------------------------
9
10
11
12
13
14
15
16
17
18
19
20
/**
 * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the polymorphic overload below
 *
 * Replacement: Use the overload of this constructor that additionally takes a
 *              std::shared_ptr<infinicore::quantization::BaseQuantization> argument.
 * Reason: Legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
/**
 * @brief Legacy convenience constructor (no quantization argument).
 *
 * Forwards to the per-projection overload, using the same `head_dim` for Q, K
 * and V, `num_kv_head` for both the K and the V head count, and the single
 * `bias` flag for all three projections.
 */
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t head_dim,
                                     size_t num_q_head,
                                     size_t num_kv_head,
                                     bool bias,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : QKVParallelLinear(hidden_size,
                        head_dim, head_dim, head_dim,
                        num_q_head, num_kv_head, num_kv_head,
                        bias, bias, bias,
                        dtype, device, rank_info) {}

/**
 * @brief Legacy per-projection constructor (no quantization argument).
 *
 * Constructs the ColumnParallelLinear base with `hidden_size` as its first
 * argument and the fused Q+K+V width
 * (`num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim`) as its
 * second, then records the per-projection dimensions, head counts and bias
 * flags and derives the per-rank slice sizes.
 *
 * @param hidden_size  Forwarded unchanged to the base class.
 * @param q_dim,k_dim,v_dim                 Per-head width of each projection.
 * @param num_q_head,num_k_head,num_v_head  Head count of each projection.
 * @param q_bias,k_bias,v_bias              Per-projection bias flags; they must all be equal.
 * @param dtype,device Forwarded unchanged to the base class.
 * @param rank_info    Supplies `tp_rank` / `tp_size` for the base class.
 *
 * @throws std::runtime_error if any head count is not divisible by `tp_size_`,
 *         or if the three bias flags are not all identical.
 */
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t q_dim, size_t k_dim, size_t v_dim,
                                     size_t num_q_head, size_t num_k_head, size_t num_v_head,
                                     bool q_bias, bool k_bias, bool v_bias,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(
          hidden_size,
          num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
          (q_bias || k_bias || v_bias),
          dtype,
          device,
          rank_info.tp_rank,
          rank_info.tp_size),
      q_dim_(q_dim),
      k_dim_(k_dim),
      v_dim_(v_dim),
      num_q_head_(num_q_head),
      num_k_head_(num_k_head),
      num_v_head_(num_v_head),
      q_bias_(q_bias),
      k_bias_(k_bias),
      v_bias_(v_bias) {
    // Heads are split across tensor-parallel ranks, so each head count has to
    // divide evenly by tp_size_ (presumably set by the base class from rank_info.tp_size).
    if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
        throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
    }

    // The base class received a single OR-ed bias flag, so mixed per-projection
    // bias settings are rejected.
    if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
        throw std::runtime_error("q_bias, k_bias, v_bias must all match");
    }

    // Per-rank widths of the Q, K and V slices; used as offsets/lengths by
    // forward_split() and the get_*_weight/bias accessors.
    q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
    k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
    v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
}

/**
 * @brief Convenience constructor taking a quantization object.
 *
 * Forwards to the per-projection quantized overload, using the same `head_dim`
 * for Q, K and V, `num_kv_head` for both the K and the V head count, and the
 * single `bias` flag for all three projections. `quantization` is passed
 * through unchanged.
 */
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t head_dim,
                                     size_t num_q_head,
                                     size_t num_kv_head,
                                     std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                     bool bias,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : QKVParallelLinear(hidden_size,
                        head_dim, head_dim, head_dim,
                        num_q_head, num_kv_head, num_kv_head,
                        bias, bias, bias,
                        quantization,
                        dtype, device, rank_info) {}

/**
 * @brief Per-projection constructor taking a quantization object.
 *
 * Constructs the ColumnParallelLinear base with `hidden_size`, the fused
 * Q+K+V width (`num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim`)
 * and `quantization`, then records the per-projection dimensions, head counts
 * and bias flags and derives the per-rank slice sizes.
 *
 * @param hidden_size  Forwarded unchanged to the base class.
 * @param q_dim,k_dim,v_dim                 Per-head width of each projection.
 * @param num_q_head,num_k_head,num_v_head  Head count of each projection.
 * @param q_bias,k_bias,v_bias              Per-projection bias flags; they must all be equal.
 * @param quantization Forwarded unchanged to the base class.
 * @param dtype,device Forwarded unchanged to the base class.
 * @param rank_info    Supplies `tp_rank` / `tp_size` for the base class.
 *
 * @throws std::runtime_error if any head count is not divisible by `tp_size_`,
 *         or if the three bias flags are not all identical.
 */
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t q_dim, size_t k_dim, size_t v_dim,
                                     size_t num_q_head, size_t num_k_head, size_t num_v_head,
                                     bool q_bias, bool k_bias, bool v_bias,
                                     std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(
          hidden_size,
          num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
          quantization,
          (q_bias || k_bias || v_bias),
          dtype,
          device,
          rank_info.tp_rank,
          rank_info.tp_size),
      q_dim_(q_dim),
      k_dim_(k_dim),
      v_dim_(v_dim),
      num_q_head_(num_q_head),
      num_k_head_(num_k_head),
      num_v_head_(num_v_head),
      q_bias_(q_bias),
      k_bias_(k_bias),
      v_bias_(v_bias) {
    // Heads are split across tensor-parallel ranks, so each head count has to
    // divide evenly by tp_size_ (presumably set by the base class from rank_info.tp_size).
    if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
        throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
    }

    // The base class received a single OR-ed bias flag, so mixed per-projection
    // bias settings are rejected.
    if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
        throw std::runtime_error("q_bias, k_bias, v_bias must all match");
    }

    // Per-rank widths of the Q, K and V slices; used as offsets/lengths by
    // forward_split() and the get_*_weight/scale/zeros/bias accessors.
    q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
    k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
    v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
}

/**
 * @brief Runs the fused projection and returns the Q, K and V slices of its output.
 *
 * The output of forward() is narrowed along dimension 2 into three consecutive
 * ranges of length q_out_size_, k_out_size_ and v_out_size_.
 * NOTE(review): this assumes forward() yields a tensor with at least three
 * dimensions — confirm against callers.
 *
 * @return Tuple of (q, k, v) tensors.
 */
std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
QKVParallelLinear::forward_split(infinicore::Tensor &input) {
    auto fused = this->forward(input);

    // Start offsets of the K and V ranges inside the fused output.
    const auto k_begin = q_out_size_;
    const auto v_begin = q_out_size_ + k_out_size_;

    return std::make_tuple(fused->narrow({{2, 0, q_out_size_}}),
                           fused->narrow({{2, k_begin, k_out_size_}}),
                           fused->narrow({{2, v_begin, v_out_size_}}));
}

/// Returns the Q part of the fused weight: the first q_out_size_ entries along
/// dimension 0, wrapped as a Parameter with (0, tp_rank_, tp_size_).
infinicore::nn::Parameter QKVParallelLinear::get_q_weight() const {
    auto q_slice = weight_->narrow({{0, 0, q_out_size_}});
    return infinicore::nn::Parameter(q_slice, 0, tp_rank_, tp_size_);
}

/// Returns the K part of the fused weight: k_out_size_ entries along
/// dimension 0, starting right after the Q part.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight() const {
    const auto k_begin = q_out_size_;
    auto k_slice = weight_->narrow({{0, k_begin, k_out_size_}});
    return infinicore::nn::Parameter(k_slice, 0, tp_rank_, tp_size_);
}

/// Returns the V part of the fused weight: v_out_size_ entries along
/// dimension 0, starting right after the Q and K parts.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const {
    const auto v_begin = q_out_size_ + k_out_size_;
    auto v_slice = weight_->narrow({{0, v_begin, v_out_size_}});
    return infinicore::nn::Parameter(v_slice, 0, tp_rank_, tp_size_);
}

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/// Returns the Q part of weight_scale_: the first q_out_size_ entries along dimension 0.
/// NOTE(review): assumes weight_scale_ has been populated (presumably only for
/// quantized layers) — confirm.
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale() const {
    auto q_slice = weight_scale_->narrow({{0, 0, q_out_size_}});
    return infinicore::nn::Parameter(q_slice, 0, tp_rank_, tp_size_);
}

/// Returns the K part of weight_scale_: k_out_size_ entries along dimension 0,
/// starting right after the Q part.
/// NOTE(review): assumes weight_scale_ has been populated — confirm.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale() const {
    const auto k_begin = q_out_size_;
    auto k_slice = weight_scale_->narrow({{0, k_begin, k_out_size_}});
    return infinicore::nn::Parameter(k_slice, 0, tp_rank_, tp_size_);
}

/// Returns the V part of weight_scale_: v_out_size_ entries along dimension 0,
/// starting right after the Q and K parts.
/// NOTE(review): assumes weight_scale_ has been populated — confirm.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
    const auto v_begin = q_out_size_ + k_out_size_;
    auto v_slice = weight_scale_->narrow({{0, v_begin, v_out_size_}});
    return infinicore::nn::Parameter(v_slice, 0, tp_rank_, tp_size_);
}

/// Returns the Q part of weight_zeros_: the first q_out_size_ entries along dimension 0.
/// NOTE(review): assumes weight_zeros_ has been populated (presumably only for
/// quantized layers) — confirm.
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
    auto q_slice = weight_zeros_->narrow({{0, 0, q_out_size_}});
    return infinicore::nn::Parameter(q_slice, 0, tp_rank_, tp_size_);
}

/// Returns the K part of weight_zeros_: k_out_size_ entries along dimension 0,
/// starting right after the Q part.
/// NOTE(review): assumes weight_zeros_ has been populated — confirm.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros() const {
    const auto k_begin = q_out_size_;
    auto k_slice = weight_zeros_->narrow({{0, k_begin, k_out_size_}});
    return infinicore::nn::Parameter(k_slice, 0, tp_rank_, tp_size_);
}

/// Returns the V part of weight_zeros_: v_out_size_ entries along dimension 0,
/// starting right after the Q and K parts.
/// NOTE(review): assumes weight_zeros_ has been populated — confirm.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros() const {
    const auto v_begin = q_out_size_ + k_out_size_;
    auto v_slice = weight_zeros_->narrow({{0, v_begin, v_out_size_}});
    return infinicore::nn::Parameter(v_slice, 0, tp_rank_, tp_size_);
}

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
/// Returns the Q part of the fused bias (first q_out_size_ entries along
/// dimension 0), or a default-constructed Parameter when q_bias_ is false.
infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const {
    if (q_bias_) {
        return infinicore::nn::Parameter(bias_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_);
    }
    return infinicore::nn::Parameter();
}

/// Returns the K part of the fused bias (k_out_size_ entries along dimension 0,
/// right after the Q part), or a default-constructed Parameter when k_bias_ is false.
infinicore::nn::Parameter QKVParallelLinear::get_k_bias() const {
    if (k_bias_) {
        const auto k_begin = q_out_size_;
        return infinicore::nn::Parameter(bias_->narrow({{0, k_begin, k_out_size_}}), 0, tp_rank_, tp_size_);
    }
    return infinicore::nn::Parameter();
}

/// Returns the V part of the fused bias (v_out_size_ entries along dimension 0,
/// right after the Q and K parts), or a default-constructed Parameter when
/// v_bias_ is false.
infinicore::nn::Parameter QKVParallelLinear::get_v_bias() const {
    if (v_bias_) {
        const auto v_begin = q_out_size_ + k_out_size_;
        return infinicore::nn::Parameter(bias_->narrow({{0, v_begin, v_out_size_}}), 0, tp_rank_, tp_size_);
    }
    return infinicore::nn::Parameter();
}

/// Returns the q_bias flag stored at construction.
bool QKVParallelLinear::has_q_bias() const { return q_bias_; }
/// Returns the k_bias flag stored at construction.
bool QKVParallelLinear::has_k_bias() const { return k_bias_; }
/// Returns the v_bias flag stored at construction.
bool QKVParallelLinear::has_v_bias() const { return v_bias_; }

// ---------------------------------------------------------
// Gate-Up Parallel Linear
// ---------------------------------------------------------
224
225
226
227
228
229
230
231
232
233
234
235
/**
 * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the polymorphic overload below
 *
 * Replacement: Use the overload of this constructor that additionally takes a
 *              std::shared_ptr<infinicore::quantization::BaseQuantization> argument.
 * Reason: Legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
/**
 * @brief Legacy convenience constructor (no quantization argument).
 *
 * Forwards to the overload with separate gate/up bias flags, passing the
 * single `bias` flag for both.
 */
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info) {
}

/**
 * @brief Legacy constructor with separate gate/up bias flags (no quantization argument).
 *
 * Constructs the ColumnParallelLinear base with `hidden_size` and a fused
 * width of `intermediate_size * 2` (gate and up halves), then stores both
 * bias flags.
 *
 * @throws std::runtime_error if `gate_bias` and `up_bias` differ.
 */
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(
          hidden_size,
          intermediate_size * 2,
          gate_bias || up_bias,
          dtype,
          device,
          rank_info.tp_rank,
          rank_info.tp_size),
      gate_bias_(gate_bias),
      up_bias_(up_bias) {
    // The base class only receives one combined bias flag, so the two halves
    // must agree.
    const bool flags_agree = (gate_bias_ == up_bias_);
    if (!flags_agree) {
        throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
    }
}

251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
/**
 * @brief Convenience constructor taking a quantization object.
 *
 * Forwards to the quantized overload with separate gate/up bias flags, passing
 * the single `bias` flag for both and `quantization` unchanged.
 */
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, quantization, dtype, device, rank_info) {
}

/**
 * @brief Constructor with separate gate/up bias flags and a quantization object.
 *
 * Constructs the ColumnParallelLinear base with `hidden_size`, a fused width
 * of `intermediate_size * 2` (gate and up halves) and `quantization`, then
 * stores both bias flags.
 *
 * @throws std::runtime_error if `gate_bias` and `up_bias` differ.
 */
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                                           std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(
          hidden_size,
          intermediate_size * 2,
          quantization,
          gate_bias || up_bias,
          dtype,
          device,
          rank_info.tp_rank,
          rank_info.tp_size),
      gate_bias_(gate_bias),
      up_bias_(up_bias) {
    // The base class only receives one combined bias flag, so the two halves
    // must agree.
    const bool flags_agree = (gate_bias_ == up_bias_);
    if (!flags_agree) {
        throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
    }
}

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
/**
 * @brief Runs the fused projection and splits its output into gate and up halves.
 *
 * The output of forward() is halved along dimension 2: the first half is the
 * gate output, the second half the up output.
 * NOTE(review): this assumes forward() yields a tensor with at least three
 * dimensions — confirm against callers.
 *
 * @return Tuple of (gate, up) tensors.
 */
std::tuple<infinicore::Tensor, infinicore::Tensor> GateUpParallelLinear::forward_split(infinicore::Tensor &input) {
    auto fused = this->forward(input);
    const auto half = fused->shape()[2] / 2;
    return std::make_tuple(fused->narrow({{2, 0, half}}),
                           fused->narrow({{2, half, half}}));
}

/// Returns the gate part of the fused weight: the first half of dimension 0.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight() const {
    const auto half = weight_->size(0) / 2;
    auto gate_slice = weight_->narrow({{0, 0, half}});
    return infinicore::nn::Parameter(gate_slice, 0, tp_rank_, tp_size_);
}

/// Returns the gate part of the fused bias (first half of dimension 0), or a
/// default-constructed Parameter when gate_bias_ is false.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_bias() const {
    // No gate bias configured: hand back an empty Parameter.
    if (!gate_bias_) {
        return infinicore::nn::Parameter();
    }

    const auto half = bias_->size(0) / 2;
    return infinicore::nn::Parameter(bias_->narrow({{0, 0, half}}), 0, tp_rank_, tp_size_);
}

/// Returns the up part of the fused weight: the second half of dimension 0.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight() const {
    const auto half = weight_->size(0) / 2;
    auto up_slice = weight_->narrow({{0, half, half}});
    return infinicore::nn::Parameter(up_slice, 0, tp_rank_, tp_size_);
}

/// Returns the up part of the fused bias (second half of dimension 0), or a
/// default-constructed Parameter when up_bias_ is false.
infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const {
    // No up bias configured: hand back an empty Parameter.
    if (!up_bias_) {
        return infinicore::nn::Parameter();
    }

    const auto half = bias_->size(0) / 2;
    return infinicore::nn::Parameter(bias_->narrow({{0, half, half}}), 0, tp_rank_, tp_size_);
}

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
/// Returns the gate part of weight_scale_: the first half of dimension 0.
/// NOTE(review): assumes weight_scale_ has been populated (presumably only for
/// quantized layers) — confirm.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale() const {
    const auto half = weight_scale_->size(0) / 2;
    auto gate_slice = weight_scale_->narrow({{0, 0, half}});
    return infinicore::nn::Parameter(gate_slice, 0, tp_rank_, tp_size_);
}

/// Returns the up part of weight_scale_: the second half of dimension 0.
/// NOTE(review): assumes weight_scale_ has been populated — confirm.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale() const {
    const auto half = weight_scale_->size(0) / 2;
    auto up_slice = weight_scale_->narrow({{0, half, half}});
    return infinicore::nn::Parameter(up_slice, 0, tp_rank_, tp_size_);
}

/// Returns the gate part of weight_zeros_: the first half of dimension 0.
/// NOTE(review): assumes weight_zeros_ has been populated (presumably only for
/// quantized layers) — confirm.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros() const {
    const auto half = weight_zeros_->size(0) / 2;
    auto gate_slice = weight_zeros_->narrow({{0, 0, half}});
    return infinicore::nn::Parameter(gate_slice, 0, tp_rank_, tp_size_);
}

/// Returns the up part of weight_zeros_: the second half of dimension 0.
/// NOTE(review): assumes weight_zeros_ has been populated — confirm.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros() const {
    const auto half = weight_zeros_->size(0) / 2;
    auto up_slice = weight_zeros_->narrow({{0, half, half}});
    return infinicore::nn::Parameter(up_slice, 0, tp_rank_, tp_size_);
}

316
317
318
319
320
321
322
323
/// Returns the gate_bias flag stored at construction.
bool GateUpParallelLinear::has_gate_bias() const {
    return gate_bias_;
}

/// Returns the up_bias flag stored at construction.
bool GateUpParallelLinear::has_up_bias() const {
    return up_bias_;
}
} // namespace infinilm::layers