#pragma once

#include "../../backends/attention_backends.hpp"
#include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp"
#include "../../engine/distributed/distributed.hpp"
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"

#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"
#include "llama_config.hpp"
#include <algorithm>
#include <memory>
#include <optional>
#include <utility>

namespace infinilm::models::llama {

/**
 * @brief Multi-head self-attention module for Llama
 *
 * Implements the attention mechanism with:
 * - Query, Key, Value projections
 * - Output projection
 * - Rotary Position Embeddings (RoPE) applied to Q and K
 * - Support for Grouped Query Attention (GQA)
 */
class LlamaAttention : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaAttention module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
     * @param layer_idx Layer index for cache access
     * @param rank_info Distributed rank information (optional)
     * @param attention_backend Attention backend to use (defaults to AttentionBackend::Default)
     */
    /**
     * @deprecated This constructor overload is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions permitted on this interface
     *   - Only critical bug fixes (security/stability) allowed until removal
     *   - All new code MUST migrate to the polymorphic overload below
     *
     * Replacement: use the constructor overload below that takes a std::shared_ptr<infinilm::config::ModelConfig>.
     * Reason: Legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
    LlamaAttention(const LlamaConfig &config,
                   const infinicore::Device &device,
                   size_t layer_idx,
                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
                   backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

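    /**
     * @brief Construct LlamaAttention from a generic ModelConfig (the replacement overload
     *        referred to by the deprecation notice above)
     *
     * @param model_config Shared model configuration
     * @param device Device to create tensors on
     * @param layer_idx Layer index for cache access
     * @param rank_info Distributed rank information (optional)
     * @param attention_backend Attention backend to use (defaults to AttentionBackend::Default)
     */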
    LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                   const infinicore::Device &device,
                   size_t layer_idx,
                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
                   backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

    /**
     * @brief Forward pass: compute attention
     *
     * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
     * @param kv_cache Optional model-level KV cache for incremental decoding
     * @param past_sequence_lengths Optional per-sequence lengths of previously cached tokens
     * @param total_sequence_lengths Optional per-sequence total lengths (past + current tokens)
     * @param input_offsets Optional per-sequence offsets into the flattened input
     * @param cu_seqlens Optional cumulative sequence lengths for variable-length batching
     * @param block_tables Optional block tables for the paged KV cache
     * @param slot_mapping Optional slot mapping for paged KV-cache writes
     * @return Output tensor of shape [batch, seq_len, hidden_size]
     */
    infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
                               const infinicore::Tensor &position_ids,
                               std::shared_ptr<infinilm::cache::Cache> kv_cache,
                               std::optional<infinicore::Tensor> past_sequence_lengths,
                               std::optional<infinicore::Tensor> total_sequence_lengths,
                               std::optional<infinicore::Tensor> input_offsets,
                               std::optional<infinicore::Tensor> cu_seqlens,
                               std::optional<infinicore::Tensor> block_tables,
                               std::optional<infinicore::Tensor> slot_mapping) const;
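
    // Illustrative usage sketch (not part of this header's API): `attn`, `hidden`,
    // `positions`, `cache`, and `rope` are placeholder names. The shared RoPE module
    // is expected to be supplied by the parent model via set_rotary_emb(); unused
    // optional arguments can be passed as std::nullopt, e.g. for a non-paged prefill:
    //
    //   attn.set_rotary_emb(rope);
    //   auto out = attn.forward(hidden, positions, cache,
    //                           /*past_sequence_lengths=*/std::nullopt,
    //                           /*total_sequence_lengths=*/std::nullopt,
    //                           /*input_offsets=*/std::nullopt,
    //                           /*cu_seqlens=*/std::nullopt,
    //                           /*block_tables=*/std::nullopt,
    //                           /*slot_mapping=*/std::nullopt);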

    /**
     * @brief Get the layer index
     */
    size_t layer_idx() const { return layer_idx_; }

    /**
     * @brief Provide shared RoPE module from parent model.
     */
    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb);

    // Module information
    size_t num_heads() const { return num_attention_heads_; }
    size_t num_kv_heads() const { return num_key_value_heads_; }
    size_t head_dim() const { return head_dim_; }
    size_t hidden_size() const { return hidden_size_; }
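
    // GQA note: when num_heads() > num_kv_heads(), each key/value head is shared by
    // num_heads() / num_kv_heads() query heads. A minimal sketch of the shape
    // arithmetic (the numbers below are illustrative, not taken from any config):
    //
    //   size_t heads = 32, kv_heads = 8, head_dim = 128;
    //   size_t group  = heads / kv_heads;     // 4 query heads per key/value head
    //   size_t q_dim  = heads * head_dim;     // 4096: query projection width
    //   size_t kv_dim = kv_heads * head_dim;  // 1024: key/value width (cf. kv_dim_)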

private:
    infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
                                const infinicore::Tensor &position_ids,
                                std::shared_ptr<infinilm::cache::Cache> kv_cache,
                                std::optional<infinicore::Tensor> past_sequence_lengths,
                                std::optional<infinicore::Tensor> total_sequence_lengths) const;

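    // Paged-attention entry point. A brief orientation, inferred from the parameter
    // names rather than guaranteed by this header: block_tables maps each sequence's
    // logical KV blocks to physical blocks of the PagedKVCache, and slot_mapping gives
    // the physical slot each new token's key/value entry is written to.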
    infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
                                      const infinicore::Tensor &position_ids,
                                      std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
                                      std::optional<infinicore::Tensor> total_sequence_lengths,
                                      std::optional<infinicore::Tensor> input_offsets,
                                      std::optional<infinicore::Tensor> cu_seqlens,
                                      std::optional<infinicore::Tensor> block_tables,
                                      std::optional<infinicore::Tensor> slot_mapping) const;

protected:
    // Projection layers
    INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
    // Optional RMSNorm applied to the Q and K projections (see use_qk_norm_)
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);
    engine::distributed::RankInfo rank_info_;

    // Shared Rotary Position Embeddings (RoPE)
    std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;

private:
    std::shared_ptr<infinilm::config::ModelConfig> model_config_ = std::make_shared<infinilm::config::ModelConfig>();
    size_t layer_idx_; // Layer index for cache access
    size_t hidden_size_;
    size_t num_attention_heads_;
    size_t num_key_value_heads_;
    size_t head_dim_;
    size_t kv_dim_;
    bool use_bias_;                  // Bias for Q/K/V projections
    bool use_output_bias_;           // Bias for output projection (o_proj)
    bool use_qk_norm_ = false;       // Whether to use QK RMSNorm
    size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)

    float scaling_; // Attention softmax scaling factor (typically 1 / sqrt(head_dim))

    backends::AttentionBackend attention_backend_;
};

} // namespace infinilm::models::llama