#include "llama_attention.hpp"

#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <optional>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <vector>

namespace infinilm::models::llama {

/**
 * @deprecated This constructor is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the ModelConfig-based overload below
 *
 * Replacement: use the overload taking std::shared_ptr<infinilm::config::ModelConfig>.
 * Reason: the legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
LlamaAttention::LlamaAttention(const LlamaConfig &config,
                               const infinicore::Device &device,
                               size_t layer_idx,
                               engine::distributed::RankInfo rank_info,
                               backends::AttentionBackend attention_backend)
    : layer_idx_(layer_idx),
      hidden_size_(config.hidden_size),
      num_attention_heads_(config.num_attention_heads),
      num_key_value_heads_(config.num_key_value_heads),
      head_dim_(config.head_dim),
      kv_dim_(config.kv_dim()),
      use_bias_(config.attention_bias),
      use_output_bias_(config.attention_output_bias),
      use_qk_norm_(config.qk_norm),
      max_position_embeddings_(config.max_position_embeddings),
      rank_info_(rank_info),
      attention_backend_(attention_backend) {
    const auto &dtype{config.dtype};

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    int num_attention_heads = config.num_attention_heads;
    int num_key_value_heads = config.num_key_value_heads;

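    // Tensor parallelism: split the attention and KV heads across ranks;
    // num_key_value_heads must be at least tp_size and divisible by it.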
    if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
        this->num_attention_heads_ = num_attention_heads / tp_size;
        this->num_key_value_heads_ = num_key_value_heads / tp_size;
    } else {
        throw std::runtime_error("num_attention_heads / tp_size error.");
    }
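    // Standard scaled-dot-product attention factor: 1 / sqrt(head_dim)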
    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

    // Initialize projection layers
    INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
                             dtype, device, rank_info);
    // Output projection uses attention_output_bias (can be different from qkv)
    INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, use_output_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);

    // Initialize qk RMSNorm
    if (use_qk_norm_) {
        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, config.rms_norm_eps, dtype, device);
        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, config.rms_norm_eps, dtype, device);
    }
}

LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                               const infinicore::Device &device,
                               size_t layer_idx,
                               engine::distributed::RankInfo rank_info,
                               backends::AttentionBackend attention_backend)
    : model_config_(model_config),
      layer_idx_(layer_idx),
      hidden_size_(model_config->get<size_t>("hidden_size")),
      num_attention_heads_(model_config->get<size_t>("num_attention_heads")),
      num_key_value_heads_(model_config->get<size_t>("num_key_value_heads")),
      head_dim_(model_config->get_head_dim()),
      kv_dim_(model_config->get_kv_dim()),
      use_bias_(model_config->get_or<bool>("attention_bias", true)),
      use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
      max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
      rank_info_(rank_info),
      attention_backend_(attention_backend) {
    const auto &dtype{model_config_->get_dtype()};

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    int num_attention_heads = model_config_->get<size_t>("num_attention_heads");
    int num_key_value_heads = model_config_->get<size_t>("num_key_value_heads");

    if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
        this->num_attention_heads_ = num_attention_heads / tp_size;
        this->num_key_value_heads_ = num_key_value_heads / tp_size;
    } else {
        throw std::runtime_error("num_attention_heads / tp_size error.");
    }
    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

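    // Pick the projection-layer implementation that matches the checkpoint's quantization scheme.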
    auto quant_scheme = this->model_config_->get_quant_scheme();
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                      dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;

    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
        INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                          dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    }
    default:
        INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                 dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    }
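    // Qwen3 applies per-head RMSNorm to the query and key projections.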
    if (model_config_->get<std::string>("model_type") == "qwen3") {
        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
    }
}

infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
                                            const infinicore::Tensor &position_ids,
                                            std::shared_ptr<infinilm::cache::Cache> kv_cache,
                                            std::optional<infinicore::Tensor> past_sequence_lengths,
                                            std::optional<infinicore::Tensor> total_sequence_lengths) const {
    // Input shape: [batch, seq_len, hidden_size]
    auto hidden_states_mutable = hidden_states;
    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];

    // 1. Project Q, K, V
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

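    // Optional per-head RMSNorm on Q and K (QK-norm, used by Qwen3-style checkpoints)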
    if (use_qk_norm_ || (model_config_ && model_config_->get_or<std::string>("model_type", "None") == "qwen3")) {
        q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
        k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
    }

    // 2. Reshape for multi-head attention
    // Reshape Q, K, V to include batch dimension
    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
    // The view operation requires the tensor to be contiguous in the required dimensions
    auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
    auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
    auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_});

    // 3. Prepare position_ids for RoPE - align with Python pattern
    // Python: bs, num = pos_ids.shape; pos_ids = pos_ids.view((bs * num,))
    auto pos_shape = position_ids->shape();
    infinicore::Tensor pos_ids_for_rope = position_ids;
    if (pos_shape.size() == 2) {
        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
        pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]});
    } else if (pos_shape.size() == 1) {
        pos_ids_for_rope = position_ids->contiguous();
    } else {
        throw std::runtime_error("Unexpected position_ids shape");
    }

    // 4. Apply RoPE to Q and K
    auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true);   // [bs, seq_len, n_kv_head, head_dim]

    // 5. Prepare KV caches
    // Convert to [batch, n_head, seq_len, head_dim] for cache
    // Ensure contiguous after permute for F16 compatibility with cache operations
    q_reshaped = q_rope->permute({0, 2, 1, 3});          // [bs, n_q_head, seq_len, head_dim]
    auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
    auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
    infinicore::Tensor k_total;                          // [bs, n_kv_head, max_seq_len, head_dim]
    infinicore::Tensor v_total;                          // [bs, n_kv_head, max_seq_len, head_dim]
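    // Without a cache, attend only over the current chunk; with a StaticKVCache, append the
    // new K/V to the cache and attend over the full cached sequence.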
    if (kv_cache == nullptr) {
        k_total = k_permuted;
        v_total = v_permuted;
    } else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
        auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, past_sequence_lengths.value());
        k_total = k_total_tmp;
        v_total = v_total_tmp;
    } else {
        throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
    }

    infinicore::Tensor attn_output;
    if (false) {
        // Experimental ninetoothed flash-attention path (intentionally disabled)
        attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
        attn_output = attn_output->permute({0, 2, 1, 3})
                          ->contiguous()
                          ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
    } else {
        size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
        k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
        v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]

        // 6. Compute attention
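        // Grouped-query attention: fold the (n_q_head / n_kv_head) query groups into the
        // sequence dimension so Q, K, and V share the same (batch * n_kv_head) leading dim.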
        size_t ngroup = num_attention_heads_ / num_key_value_heads_;
        auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
        auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
        auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});

        auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]

        auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]

        auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
        infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);

        auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]

        attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
                          ->permute({0, 2, 1, 3})
                          ->contiguous()
                          ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
    }

    auto output = o_proj_->forward(attn_output);

    return output;
}

infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
                                                  const infinicore::Tensor &position_ids,
                                                  std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
                                                  std::optional<infinicore::Tensor> total_sequence_lengths,
                                                  std::optional<infinicore::Tensor> input_offsets,
                                                  std::optional<infinicore::Tensor> cu_seqlens,
                                                  std::optional<infinicore::Tensor> block_tables,
                                                  std::optional<infinicore::Tensor> slot_mapping,
                                                  int max_seqlen_q, int max_seqlen_k) const {
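    // The paged KV-cache path requires a block table and a slot mapping.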
    ASSERT(block_tables.has_value());
    ASSERT(slot_mapping.has_value());

    // Input shape: [batch, seq_len, hidden_size]
    auto hidden_states_mutable = hidden_states;
    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];

    // Only batch_size == 1 is supported; all requests are flattened along the seq_len dimension
    ASSERT_EQ(batch_size, 1);
    // This is a decode step iff the flattened token count equals the number of requests (one new token per request)
    bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);

    // 1. Project Q, K, V
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

    // 2. Reshape for multi-head attention

    // Reshape flattened Q, K, V into per-head layout: [seq_len, n_heads, head_dim]
    // The view operation requires the tensor to be contiguous in the required dimensions
    auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
    auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
    auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});

    if (use_qk_norm_ || (model_config_ && model_config_->get_or<std::string>("model_type", "None") == "qwen3")) {
        q_reshaped = q_norm_->forward(q_reshaped);
        k_reshaped = k_norm_->forward(k_reshaped);
    }

    // 3. Prepare position_ids for RoPE - align with Python pattern
    auto pos_shape = position_ids->shape();
    infinicore::Tensor pos_ids_for_rope = position_ids;
    if (pos_shape.size() == 2) {
        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
        pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
    } else if (pos_shape.size() == 1) {
        pos_ids_for_rope = position_ids;
    } else {
        throw std::runtime_error("Unexpected position_ids shape");
    }

    // 4. Apply RoPE to Q and K
    rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); // [seq_len, n_q_head, head_dim]
    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [seq_len, n_kv_head, head_dim]

    // 5. Prepare KV caches: scatter the new K/V into the paged cache via the slot mapping
    auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
                                                     k_reshaped,
                                                     v_reshaped,
                                                     slot_mapping.value());

    // 6. Compute attention
    infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
    if (is_prefill) {
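        // Prefill: each request attends over its full (variable-length) prompt in the paged cache.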
        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
            // max_seqlen_k is supplied by the caller; the flattened seq_len is used as an
            // upper bound for max_seqlen_q. Do NOT fall back to max_position_embeddings_:
            // flash-attn's splitkv kernel would then compute far too many K-block iterations,
            // reading block_table entries that do not exist and using the garbage values as
            // KV-cache block indices, resulting in an out-of-bounds GPU memory access (VMFault).
            //
            // Previous in-layer computation, kept for reference:
            // auto total_lens_cpu = total_sequence_lengths.value()->to(infinicore::Device::cpu());
            // const auto *total_lens_ptr = reinterpret_cast<const int32_t *>(total_lens_cpu->data());
            // int n_reqs = static_cast<int>(total_sequence_lengths.value()->shape()[0]);
            // int max_seqlen_k = 0;
            // for (int i = 0; i < n_reqs; ++i) {
            //     max_seqlen_k = std::max(max_seqlen_k, total_lens_ptr[i]);
            // }
            // // max_seqlen_q: with batch_size==1 the flattened seq_len equals the per-request length.
            // int max_seqlen_q = static_cast<int>(seq_len);

            infinicore::op::mha_varlen_(
                attn_output,
                q_reshaped,
                k_total->permute({0, 2, 1, 3}),
                v_total->permute({0, 2, 1, 3}),
                input_offsets.value(),
                cu_seqlens.value(),
                block_tables.value(),
                static_cast<int>(seq_len),
                max_seqlen_k,
                std::nullopt,
                scaling_);
        } else {
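            // Fallback prefill kernel that works directly on the paged KV cache.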
            infinicore::op::paged_attention_prefill_(
                attn_output,
                q_reshaped,
                k_total,
                v_total,
                block_tables.value(),
                total_sequence_lengths.value(),
                input_offsets.value(),
                std::nullopt,
                scaling_);
        }
    } else {
        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
            // FA2 decode path: flash::mha_fwd_kvcache
            // In paged-attn mode, seq_len = actual batch_size (one query token per sequence).
            // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim]
            // k/v cache:  [num_blocks, num_kv_heads, block_size, head_dim]
            //           → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim]
            auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
            auto attn_out_4d = infinicore::op::mha_kvcache(
                q_for_fa,
                k_total->permute({0, 2, 1, 3}),  // [num_blocks, block_size, num_kv_heads, head_dim]
                v_total->permute({0, 2, 1, 3}),
                total_sequence_lengths.value(),  // [seq_len] int32 (one entry per sequence)
                block_tables.value(),            // [seq_len, max_num_blocks_per_seq] int32
                std::nullopt,
                scaling_);
            attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
        } else {
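            // Fallback decode kernel that works directly on the paged KV cache.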
            infinicore::op::paged_attention_(
                attn_output,
                q_reshaped,
                k_total,
                v_total,
                block_tables.value(),
                total_sequence_lengths.value(),
                std::nullopt,
                scaling_);
        }
    }

    // 7. Project output
    attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
    return o_proj_->forward(attn_output);
}

infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
                                           const infinicore::Tensor &position_ids,
                                           std::shared_ptr<cache::Cache> kv_cache,
                                           std::optional<infinicore::Tensor> past_sequence_lengths,
                                           std::optional<infinicore::Tensor> total_sequence_lengths,
                                           std::optional<infinicore::Tensor> input_offsets,
                                           std::optional<infinicore::Tensor> cu_seqlens,
                                           std::optional<infinicore::Tensor> block_tables,
                                           std::optional<infinicore::Tensor> slot_mapping,
                                           int max_seqlen_q, int max_seqlen_k) const {
    if (!rotary_emb_) {
        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
    }

    infinicore::Tensor output;
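    // Dispatch on the cache type: a PagedKVCache takes the paged path; anything else
    // (including no cache) uses the static/contiguous path.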
    if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping, max_seqlen_q, max_seqlen_k);
    } else {
        output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
    }
    return output;
}

void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
    rotary_emb_ = rotary_emb;
}

} // namespace infinilm::models::llama