// llama_attention.hpp — declaration of the Llama multi-head self-attention module.
#pragma once

#include "../../cache/kv_cache.hpp"
#include "../../engine/distributed/distributed.hpp"
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"

#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>

namespace infinilm::models::llama {

/**
 * @brief Multi-head self-attention module for Llama
 *
 * Implements the attention mechanism with:
 * - Query, Key, Value projections
 * - Output projection
 * - Rotary Position Embeddings (RoPE) applied to Q and K
 * - Support for Grouped Query Attention (GQA)
 */
class LlamaAttention : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaAttention module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
Ceng's avatar
Ceng committed
36
     * @param layer_idx Layer index for cache access
37
38
     * @param dtype Optional data type for model parameters (defaults to F32)
     */
Your Name's avatar
Your Name committed
39
40
41
42
    LlamaAttention(const LlamaConfig &config,
                   const infinicore::Device &device,
                   size_t layer_idx,
                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
43
44
45
46
47
48

    /**
     * @brief Forward pass: compute attention
     *
     * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
Ceng's avatar
Ceng committed
49
     * @param kv_cache Optional model-level KV cache for incremental decoding
50
51
52
     * @return Output tensor of shape [batch, seq_len, hidden_size]
     */
    infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
Your Name's avatar
Your Name committed
53
                               const infinicore::Tensor &position_ids,
PanZezhong's avatar
PanZezhong committed
54
                               std::shared_ptr<infinilm::cache::Cache> kv_cache,
55
56
                               std::optional<infinicore::Tensor> past_sequence_lengths,
                               std::optional<infinicore::Tensor> total_sequence_lengths,
57
58
                               std::optional<infinicore::Tensor> input_offsets,
                               std::optional<infinicore::Tensor> block_tables,
59
                               std::optional<infinicore::Tensor> slot_mapping) const;
60

Ceng's avatar
Ceng committed
61
62
63
64
65
    /**
     * @brief Get the layer index
     */
    size_t layer_idx() const { return layer_idx_; }

66
67
68
69
70
71
72
73
74
75
76
    /**
     * @brief Provide shared RoPE module from parent model.
     */
    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb);

    // Module information
    size_t num_heads() const { return num_attention_heads_; }
    size_t num_kv_heads() const { return num_key_value_heads_; }
    size_t head_dim() const { return head_dim_; }
    size_t hidden_size() const { return hidden_size_; }

77
78
79
80
private:
    infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
                                const infinicore::Tensor &position_ids,
                                std::shared_ptr<infinilm::cache::Cache> kv_cache,
81
82
                                std::optional<infinicore::Tensor> past_sequence_lengths,
                                std::optional<infinicore::Tensor> total_sequence_lengths) const;
83
84
85
86

    infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
                                      const infinicore::Tensor &position_ids,
                                      std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
87
                                      std::optional<infinicore::Tensor> total_sequence_lengths,
88
89
90
91
                                      std::optional<infinicore::Tensor> input_offsets,
                                      std::optional<infinicore::Tensor> block_tables,
                                      std::optional<infinicore::Tensor> slot_mapping) const;

92
93
protected:
    // Projection layers
94
    INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
Your Name's avatar
Your Name committed
95
    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
wangpengcheng's avatar
wangpengcheng committed
96
97
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);
Your Name's avatar
Your Name committed
98
    engine::distributed::RankInfo rank_info_;
99
100
101
102
103

    // Shared Rotary Position Embeddings (RoPE)
    std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;

private:
Your Name's avatar
Your Name committed
104
    size_t layer_idx_; // Layer index for cache access
105
106
107
108
109
    size_t hidden_size_;
    size_t num_attention_heads_;
    size_t num_key_value_heads_;
    size_t head_dim_;
    size_t kv_dim_;
Your Name's avatar
Your Name committed
110
111
    bool use_bias_;                  // Bias for Q/K/V projections
    bool use_output_bias_;           // Bias for output projection (o_proj)
wangpengcheng's avatar
wangpengcheng committed
112
    bool use_qk_norm_;               // Whether to use QK RMSNorm
Your Name's avatar
Your Name committed
113
    size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
114
115

    float scaling_;
116
117
118
};

} // namespace infinilm::models::llama