#include "llama_attention.hpp"

#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <optional>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <vector>

namespace infinilm::models::llama {

/**
 * @deprecated This constructor overload is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the ModelConfig-based overload below
 *
 * Replacement: Use the ModelConfig-based overload of this constructor (updated signature below).
 * Reason: The legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
LlamaAttention::LlamaAttention(const LlamaConfig &config,
                               const infinicore::Device &device,
                               size_t layer_idx,
                               engine::distributed::RankInfo rank_info,
                               backends::AttentionBackend attention_backend)
    : layer_idx_(layer_idx),
      hidden_size_(config.hidden_size),
      num_attention_heads_(config.num_attention_heads),
      num_key_value_heads_(config.num_key_value_heads),
      head_dim_(config.head_dim),
      kv_dim_(config.kv_dim()),
      use_bias_(config.attention_bias),
      use_output_bias_(config.attention_output_bias),
      use_qk_norm_(config.qk_norm),
      max_position_embeddings_(config.max_position_embeddings),
      rank_info_(rank_info),
      attention_backend_(attention_backend) {
    const auto &dtype{config.dtype};

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    int num_attention_heads = config.num_attention_heads;
    int num_key_value_heads = config.num_key_value_heads;

    // Partition attention and KV heads evenly across the tensor-parallel ranks.
    if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
        this->num_attention_heads_ = num_attention_heads / tp_size;
        this->num_key_value_heads_ = num_key_value_heads / tp_size;
    } else {
        throw std::runtime_error("num_key_value_heads must be >= tp_size and divisible by tp_size.");
    }
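    // Standard scaled dot-product attention scale: 1 / sqrt(head_dim).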
    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

    // Initialize projection layers
    INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
                             dtype, device, rank_info);
    // Output projection uses attention_output_bias (can be different from qkv)
    INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, use_output_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);

    // Initialize qk RMSNorm
    if (use_qk_norm_) {
        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, config.rms_norm_eps, dtype, device);
        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, config.rms_norm_eps, dtype, device);
    }
}

LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                               const infinicore::Device &device,
                               size_t layer_idx,
                               engine::distributed::RankInfo rank_info,
                               backends::AttentionBackend attention_backend)
    : model_config_(model_config),
      layer_idx_(layer_idx),
      hidden_size_(model_config->get<size_t>("hidden_size")),
      num_attention_heads_(model_config->get<size_t>("num_attention_heads")),
      num_key_value_heads_(model_config->get<size_t>("num_key_value_heads")),
      head_dim_(model_config->get_head_dim()),
      kv_dim_(model_config->get_kv_dim()),
      use_bias_(model_config->get_or<bool>("attention_bias", true)),
      use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
      max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
      rank_info_(rank_info),
      attention_backend_(attention_backend) {
    const auto &dtype{model_config_->get_dtype()};

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    int num_attention_heads = model_config_->get<size_t>("num_attention_heads");
    int num_key_value_heads = model_config_->get<size_t>("num_key_value_heads");

    // Partition attention and KV heads evenly across the tensor-parallel ranks.
    if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
        this->num_attention_heads_ = num_attention_heads / tp_size;
        this->num_key_value_heads_ = num_key_value_heads / tp_size;
    } else {
        throw std::runtime_error("num_key_value_heads must be >= tp_size and divisible by tp_size.");
    }
    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

    auto quant_scheme = this->model_config_->get_quant_scheme();
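    // Select the projection-layer implementation that matches the checkpoint's weight
    // quantization scheme: W8A8 (compressed-tensors), AWQ W4A16, or the unquantized default.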
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                      dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;

    case infinicore::quantization::QuantScheme::AWQ_W4A16:
        INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                          dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    default:
        INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                 dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    }
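    // Qwen3 checkpoints use per-head RMSNorm on Q and K (applied before RoPE in forward).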
    if (model_config_->get<std::string>("model_type") == "qwen3") {
        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
    }
}

infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
                                            const infinicore::Tensor &position_ids,
                                            std::shared_ptr<infinilm::cache::Cache> kv_cache,
                                            std::optional<infinicore::Tensor> past_sequence_lengths,
                                            std::optional<infinicore::Tensor> total_sequence_lengths) const {
    // Input shape: [batch, seq_len, hidden_size]
    auto hidden_states_mutable = hidden_states;
    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];

    // 1. Project Q, K, V
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

    // Guard model_config_: it is null when constructed from the legacy LlamaConfig overload.
    if (use_qk_norm_ || (model_config_ && model_config_->get_or<std::string>("model_type", "None") == "qwen3")) {
        q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
        k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
    }

    // 2. Reshape for multi-head attention
    // Reshape Q, K, V to include batch dimension
    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
    // The view operation requires the tensor to be contiguous in the required dimensions
    auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
    auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
    auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_});

    // 3. Prepare position_ids for RoPE - align with Python pattern
    // Python: bs, num = pos_ids.shape; pos_ids = pos_ids.view((bs * num,))
    auto pos_shape = position_ids->shape();
    infinicore::Tensor pos_ids_for_rope = position_ids;
    if (pos_shape.size() == 2) {
        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
        pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]});
    } else if (pos_shape.size() == 1) {
        pos_ids_for_rope = position_ids->contiguous();
    } else {
        throw std::runtime_error("Unexpected position_ids shape");
    }

    // 4. Apply RoPE to Q and K
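    // q_rope is allocated with [bs, n_q_head, seq_len, head_dim] storage and permuted to
    // [bs, seq_len, n_q_head, head_dim], so the permute back for the cache below stays contiguous.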
    auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true);   // [bs, seq_len, n_kv_head, head_dim]

    // 5. Prepare KV caches
    // Convert to [batch, n_head, seq_len, head_dim] for cache
    // Ensure contiguous after permute for F16 compatibility with cache operations
    q_reshaped = q_rope->permute({0, 2, 1, 3});          // [bs, n_q_head, seq_len, head_dim]
    auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
    auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
    infinicore::Tensor k_total;                          // [bs, n_kv_head, max_seq_len, head_dim]
    infinicore::Tensor v_total;                          // [bs, n_kv_head, max_seq_len, head_dim]
    if (kv_cache == nullptr) {
        k_total = k_permuted;
        v_total = v_permuted;
    } else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
        auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, past_sequence_lengths.value());
        k_total = k_total_tmp;
        v_total = v_total_tmp;
    } else {
        throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
    }

    infinicore::Tensor attn_output;
    // Experimental nineoothed flash-attention path, currently disabled.
    constexpr bool use_experimental_flash_attention = false;
    if (use_experimental_flash_attention) {
        attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
        attn_output = attn_output->permute({0, 2, 1, 3})
                          ->contiguous()
                          ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
    } else {
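        // Copy total_sequence_lengths to the host to read the scalar length that bounds the cached context.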
        size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
        k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
        v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]

        // 6. Compute attention
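        // Grouped-query attention: the ngroup query heads that share a KV head are folded into
        // the sequence dimension, so one batched matmul per KV head covers the whole group.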
        size_t ngroup = num_attention_heads_ / num_key_value_heads_;
        auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
        auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
        auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});

        auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]

        auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]

        auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
        infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);

        auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]

        attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
                          ->permute({0, 2, 1, 3})
                          ->contiguous()
                          ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
    }

    auto output = o_proj_->forward(attn_output);

    return output;
}

infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
                                                  const infinicore::Tensor &position_ids,
                                                  std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
                                                  std::optional<infinicore::Tensor> total_sequence_lengths,
                                                  std::optional<infinicore::Tensor> input_offsets,
                                                  std::optional<infinicore::Tensor> cu_seqlens,
                                                  std::optional<infinicore::Tensor> block_tables,
                                                  std::optional<infinicore::Tensor> slot_mapping) const {
    ASSERT(block_tables.has_value());
    ASSERT(slot_mapping.has_value());

    // Input shape: [batch, seq_len, hidden_size]
    auto hidden_states_mutable = hidden_states;
    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];

    // Only batch_size == 1 is supported; all requests are flattened along the seq_len dimension.
    ASSERT_EQ(batch_size, 1);
    // Decode-only step when each request contributes exactly one token, i.e. seq_len == num_requests.
    bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);

    // 1. Project Q, K, V
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

    // 2. Reshape for multi-head attention

    // Reshape Q, K, V to include batch dimension
    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
    // The view operation requires the tensor to be contiguous in the required dimensions
    auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
    auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
    auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});

    // Guard model_config_: it is null when constructed from the legacy LlamaConfig overload.
    if (use_qk_norm_ || (model_config_ && model_config_->get_or<std::string>("model_type", "None") == "qwen3")) {
        q_reshaped = q_norm_->forward(q_reshaped);
        k_reshaped = k_norm_->forward(k_reshaped);
    }

    // 3. Prepare position_ids for RoPE - align with Python pattern
    auto pos_shape = position_ids->shape();
    infinicore::Tensor pos_ids_for_rope = position_ids;
    if (pos_shape.size() == 2) {
        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
        pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
    } else if (pos_shape.size() == 1) {
        pos_ids_for_rope = position_ids;
    } else {
        throw std::runtime_error("Unexpected position_ids shape");
    }

    // 4. Apply RoPE to Q and K
    rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); // [seq_len, n_q_head, head_dim]
    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [seq_len, n_kv_head, head_dim]

    // 5. Prepare KV caches: write the new K/V of each token into the paged cache,
    //    with slot_mapping giving the destination cache slot per token.
    auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
                                                     k_reshaped,
                                                     v_reshaped,
                                                     slot_mapping.value());

    // 6. Compute attention
    infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());

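    // Prefill attends over full variable-length prompts (FlashAttn varlen or paged prefill);
    // decode attends each new token against the cached context via paged attention.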
    if (is_prefill) {
        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
            infinicore::op::mha_varlen_(
                attn_output,
                q_reshaped,
                k_total->permute({0, 2, 1, 3}),
                v_total->permute({0, 2, 1, 3}),
                input_offsets.value(),
                cu_seqlens.value(),
                block_tables.value(),
                max_position_embeddings_,
                max_position_embeddings_,
                std::nullopt,
                scaling_);
        } else {
            infinicore::op::paged_attention_prefill_(
                attn_output,
                q_reshaped,
                k_total,
                v_total,
                block_tables.value(),
                total_sequence_lengths.value(),
                input_offsets.value(),
                std::nullopt,
                scaling_);
        }
    } else {
        infinicore::op::paged_attention_(
            attn_output,
            q_reshaped,
            k_total,
            v_total,
            block_tables.value(),
            total_sequence_lengths.value(),
            std::nullopt,
            scaling_);
    }

    // 7. Project output
    attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
    return o_proj_->forward(attn_output);
}

infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
                                           const infinicore::Tensor &position_ids,
                                           std::shared_ptr<cache::Cache> kv_cache,
                                           std::optional<infinicore::Tensor> past_sequence_lengths,
                                           std::optional<infinicore::Tensor> total_sequence_lengths,
                                           std::optional<infinicore::Tensor> input_offsets,
                                           std::optional<infinicore::Tensor> cu_seqlens,
                                           std::optional<infinicore::Tensor> block_tables,
                                           std::optional<infinicore::Tensor> slot_mapping) const {
    if (!rotary_emb_) {
        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
    }

    infinicore::Tensor output;
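    // Dispatch on the cache type: a PagedKVCache takes the paged-attention path,
    // anything else falls back to the static / contiguous KV-cache path.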
    if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
    } else {

        output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
    }
    return output;
}

void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
    rotary_emb_ = rotary_emb;
}

} // namespace infinilm::models::llama