#pragma once

#include "../../cache/kv_cache.hpp"
#include "../../engine/distributed/distributed.hpp"
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"

#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"
#include <algorithm>
#include <memory>
#include <utility>

namespace infinilm::models::llama {

/**
 * @brief Multi-head self-attention module for Llama
 *
 * Implements the attention mechanism with:
 * - Query, Key, Value projections
 * - Output projection
 * - Rotary Position Embeddings (RoPE) applied to Q and K
 * - Support for Grouped Query Attention (GQA)
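 *
 * Under GQA, K and V use fewer heads than Q. A sketch of the derived
 * sizes, following the usual Llama convention (the concrete numbers are
 * illustrative, not taken from this header):
 * @code
 * size_t head_dim = hidden_size / num_attention_heads; // e.g. 4096 / 32 = 128
 * size_t kv_dim   = num_key_value_heads * head_dim;    // e.g.    8 * 128 = 1024
 * @endcode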
 */
class LlamaAttention : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaAttention module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
     * @param layer_idx Layer index for cache access
     * @param dtype Optional data type for model parameters (defaults to F32)
     * @param rank_info Distributed rank info for tensor-parallel execution (defaults to a single-rank configuration)
     */
    LlamaAttention(const LlamaConfig &config,
                   const infinicore::Device &device,
                   size_t layer_idx,
                   infinicore::DataType dtype = infinicore::DataType::F32,
                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
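
    // Construction sketch (hypothetical: `cfg` is a populated LlamaConfig and
    // `dev` an infinicore::Device; dtype and rank_info keep their defaults):
    //
    //   LlamaAttention attn(cfg, dev, /*layer_idx=*/0);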

    /**
     * @brief Forward pass: compute attention
     *
     * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
     * @param kv_cache Optional model-level KV cache for incremental decoding
     * @return Output tensor of shape [batch, seq_len, hidden_size]
     */
    infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
                               const infinicore::Tensor &position_ids,
                               void *kv_cache = nullptr) const;
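
    // Call sketch (hypothetical names `attn`, `hidden_states`, `position_ids`,
    // and `cache`; shapes as documented above):
    //
    //   infinicore::Tensor out = attn.forward(hidden_states, position_ids,
    //                                         /*kv_cache=*/&cache);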

    /**
     * @brief Get the layer index
     */
    size_t layer_idx() const { return layer_idx_; }

    /**
     * @brief Provide the shared RoPE module from the parent model.
     */
    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb);
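
    // Sharing sketch (hypothetical parent-model code; RoPE constructor
    // arguments are elided since this header does not specify them):
    //
    //   auto rope = std::make_shared<infinicore::nn::RoPE>(/*...*/);
    //   for (auto &layer : layers) { layer->self_attn->set_rotary_emb(rope); }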

    // Module information
    size_t num_heads() const { return num_attention_heads_; }
    size_t num_kv_heads() const { return num_key_value_heads_; }
    size_t head_dim() const { return head_dim_; }
    size_t hidden_size() const { return hidden_size_; }

protected:
    // Projection layers
    INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);

    engine::distributed::RankInfo rank_info_;

    // Shared Rotary Position Embeddings (RoPE)
    std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;

private:
    size_t layer_idx_; // Layer index for cache access
    size_t hidden_size_;         // Model hidden dimension
    size_t num_attention_heads_; // Number of query heads
    size_t num_key_value_heads_; // Number of K/V heads (<= num_attention_heads_ under GQA)
    size_t head_dim_;            // Dimension of each attention head
    size_t kv_dim_;              // Total K/V projection width (num_key_value_heads_ * head_dim_)
    bool use_bias_;                  // Bias for Q/K/V projections
    bool use_output_bias_;           // Bias for output projection (o_proj)
    size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
};

} // namespace infinilm::models::llama