// llama_attention.hpp — declaration of the Llama multi-head self-attention module.
#pragma once

#include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp"
#include "../../engine/distributed/distributed.hpp"
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"

#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>

namespace infinilm::models::llama {

/**
 * @brief Multi-head self-attention module for Llama
 *
 * Implements the attention mechanism with:
 * - Query, Key, Value projections
 * - Output projection
 * - Rotary Position Embeddings (RoPE) applied to Q and K
 * - Support for Grouped Query Attention (GQA)
 */
class LlamaAttention : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaAttention module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
Ceng's avatar
Ceng committed
37
     * @param layer_idx Layer index for cache access
38
39
     * @param dtype Optional data type for model parameters (defaults to F32)
     */
40
41
42
43
44
45
46
47
48
49
50
51
    /**
     * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions permitted on this interface
     *   - Only critical bug fixes (security/stability) allowed until removal
     *   - All new code MUST migrate to the polymorphic overload below
     *
     * Replacement: Use the polymorphic overload of this same function name with updated signature
     * Reason: Legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
Your Name's avatar
Your Name committed
52
53
54
55
    LlamaAttention(const LlamaConfig &config,
                   const infinicore::Device &device,
                   size_t layer_idx,
                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
56

57
58
59
60
61
    LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                   const infinicore::Device &device,
                   size_t layer_idx,
                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

62
63
64
65
66
    /**
     * @brief Forward pass: compute attention
     *
     * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
Ceng's avatar
Ceng committed
67
     * @param kv_cache Optional model-level KV cache for incremental decoding
68
69
70
     * @return Output tensor of shape [batch, seq_len, hidden_size]
     */
    infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
Your Name's avatar
Your Name committed
71
                               const infinicore::Tensor &position_ids,
PanZezhong's avatar
PanZezhong committed
72
                               std::shared_ptr<infinilm::cache::Cache> kv_cache,
73
74
                               std::optional<infinicore::Tensor> past_sequence_lengths,
                               std::optional<infinicore::Tensor> total_sequence_lengths,
75
76
                               std::optional<infinicore::Tensor> input_offsets,
                               std::optional<infinicore::Tensor> block_tables,
77
                               std::optional<infinicore::Tensor> slot_mapping) const;
78

Ceng's avatar
Ceng committed
79
80
81
82
83
    /**
     * @brief Get the layer index
     */
    size_t layer_idx() const { return layer_idx_; }

84
85
86
87
88
89
90
91
92
93
94
    /**
     * @brief Provide shared RoPE module from parent model.
     */
    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb);

    // Module information
    size_t num_heads() const { return num_attention_heads_; }
    size_t num_kv_heads() const { return num_key_value_heads_; }
    size_t head_dim() const { return head_dim_; }
    size_t hidden_size() const { return hidden_size_; }

95
96
97
98
private:
    infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
                                const infinicore::Tensor &position_ids,
                                std::shared_ptr<infinilm::cache::Cache> kv_cache,
99
100
                                std::optional<infinicore::Tensor> past_sequence_lengths,
                                std::optional<infinicore::Tensor> total_sequence_lengths) const;
101
102
103
104

    infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
                                      const infinicore::Tensor &position_ids,
                                      std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
105
                                      std::optional<infinicore::Tensor> total_sequence_lengths,
106
107
108
109
                                      std::optional<infinicore::Tensor> input_offsets,
                                      std::optional<infinicore::Tensor> block_tables,
                                      std::optional<infinicore::Tensor> slot_mapping) const;

110
111
protected:
    // Projection layers
112
    INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
Your Name's avatar
Your Name committed
113
    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
wangpengcheng's avatar
wangpengcheng committed
114
115
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);
Your Name's avatar
Your Name committed
116
    engine::distributed::RankInfo rank_info_;
117
118
119
120
121

    // Shared Rotary Position Embeddings (RoPE)
    std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;

private:
122
    std::shared_ptr<infinilm::config::ModelConfig> model_config_ = std::make_shared<infinilm::config::ModelConfig>();
Your Name's avatar
Your Name committed
123
    size_t layer_idx_; // Layer index for cache access
124
125
126
127
128
    size_t hidden_size_;
    size_t num_attention_heads_;
    size_t num_key_value_heads_;
    size_t head_dim_;
    size_t kv_dim_;
Your Name's avatar
Your Name committed
129
130
    bool use_bias_;                  // Bias for Q/K/V projections
    bool use_output_bias_;           // Bias for output projection (o_proj)
131
    bool use_qk_norm_ = false;       // Whether to use QK RMSNorm
Your Name's avatar
Your Name committed
132
    size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
133
134

    float scaling_;
135
136
137
};

} // namespace infinilm::models::llama