#include "llama_attention.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mul.hpp"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <vector>

namespace infinilm::models::llama {

/// Constructs one Llama attention layer.
///
/// Shards the attention and key/value heads across the tensor-parallel group
/// described by @p rank_info, then initializes the fused QKV projection and
/// the output projection.
///
/// @param config    Model hyperparameters (head counts, dims, bias flags).
/// @param device    Device on which the projection weights are created.
/// @param layer_idx Index of this layer (used to address its KV-cache slot).
/// @param dtype     Parameter data type.
/// @param rank_info Tensor-parallel rank/size and communicator handle.
/// @throws std::runtime_error if either head count cannot be evenly divided
///         by the tensor-parallel world size.
LlamaAttention::LlamaAttention(const LlamaConfig &config,
                               const infinicore::Device &device,
                               size_t layer_idx,
                               infinicore::DataType dtype,
                               engine::distributed::RankInfo rank_info)
    : layer_idx_(layer_idx),
      hidden_size_(config.hidden_size),
      num_attention_heads_(config.num_attention_heads),
      num_key_value_heads_(config.num_key_value_heads),
      head_dim_(config.head_dim),
      kv_dim_(config.kv_dim()),
      use_bias_(config.attention_bias),
      use_output_bias_(config.attention_output_bias),
      max_position_embeddings_(config.max_position_embeddings),
      rank_info_(rank_info) {

    const int tp_rank = rank_info.tp_rank;
    const int tp_size = rank_info.tp_size;

    const int num_attention_heads = config.num_attention_heads;
    const int num_key_value_heads = config.num_key_value_heads;

    // Shard heads across the tensor-parallel group. BOTH head counts must be
    // divisible by tp_size; otherwise the integer division below would
    // silently drop heads. (Previously only the KV-head count was validated,
    // and the error message blamed num_attention_heads.)
    if (num_key_value_heads < tp_size || (num_key_value_heads % tp_size) != 0) {
        throw std::runtime_error(
            "LlamaAttention: num_key_value_heads must be >= tp_size and divisible by tp_size.");
    }
    if ((num_attention_heads % tp_size) != 0) {
        throw std::runtime_error(
            "LlamaAttention: num_attention_heads must be divisible by tp_size.");
    }
    this->num_attention_heads_ = num_attention_heads / tp_size;
    this->num_key_value_heads_ = num_key_value_heads / tp_size;

    // Initialize projection layers. The fused QKV linear receives the GLOBAL
    // head counts; sharding is handled inside via rank_info.
    INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
                             dtype, device, rank_info);
    // Output projection uses attention_output_bias (can be different from qkv)
    INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);
}

/// Runs one self-attention pass: QKV projection -> KV-cache update -> RoPE ->
/// grouped-query scaled-dot-product attention -> output projection.
///
/// @param hidden_states expected rank-3 [batch, seq_len, hidden_size]
///                      (shape[0]/shape[1] are unpacked below).
/// @param position_ids  [batch, seq_len] or [seq_len]; flattened for RoPE.
///                      NOTE(review): in the 2-D case only row 0 is used —
///                      assumes all batch rows share the same positions.
/// @param kv_cache      non-null pointer to infinilm::cache::DynamicCache;
///                      nullptr raises std::runtime_error.
/// @return              [batch, seq_len, hidden_size] attention output.
/// @throws std::runtime_error if rotary_emb was never set, kv_cache is null,
///         or position_ids has an unexpected rank.
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
                                           const infinicore::Tensor &position_ids,
                                           void *kv_cache) const {
    // RoPE module is injected separately via set_rotary_emb(); fail fast if
    // the caller forgot to wire it up.
    if (!rotary_emb_) {
        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
    }
    // Input shape: [batch, seq_len, hidden_size]
    auto hidden_states_mutable = hidden_states;
    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];

    // 1. Project Q, K, V in one fused linear; forward_split returns the three
    //    slices separately.
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

    // 2. Reshape for multi-head attention
    // Reshape Q, K, V to include batch dimension
    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
    // The view operation requires the tensor to be contiguous in the required dimensions
    auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
    auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
    auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_});

    // 3. Prepare position_ids for RoPE - align with Python pattern
    // Python: bs, num = pos_ids.shape; pos_ids = pos_ids.view((bs * num,))
    auto pos_shape = position_ids->shape();
    infinicore::Tensor pos_ids_for_rope = position_ids;
    if (pos_shape.size() == 2) {
        // Take only batch row 0 and flatten to [seq_len].
        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
        pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]});
    } else if (pos_shape.size() == 1) {
        pos_ids_for_rope = position_ids->contiguous();
    } else {
        throw std::runtime_error("Unexpected position_ids shape");
    }

    // 4. Prepare KV caches
    // Convert to [batch, n_head, seq_len, head_dim] for cache
    // Ensure contiguous after permute for F16 compatibility with cache operations
    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
    auto k_permuted = k_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
    auto v_permuted = v_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
    infinilm::cache::DynamicCache *external_cache = static_cast<infinilm::cache::DynamicCache *>(kv_cache);
    infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
    infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
    if (external_cache != nullptr) {
        // Append this step's K/V and get back the full cached sequence.
        // NOTE(review): K is appended BEFORE RoPE is applied; step 5 below
        // rotates the newly-appended tail in place through a view of
        // k_total — assumes update() returns tensors aliasing the cache
        // storage. Confirm against DynamicCache::update.
        auto [k_total_tmp, v_total_tmp] = external_cache->update(layer_idx_, k_permuted, v_permuted);
        k_total = k_total_tmp;
        v_total = v_total_tmp;
    } else {
        // No external cache - this shouldn't happen in normal operation, but handle gracefully
        throw std::runtime_error("LlamaAttention: kv_cache is required but nullptr provided");
    }
    auto total_seq_len = k_total->shape()[2];

    // 5. Apply RoPE to full batch. Only the last seq_len positions of K are
    //    rotated (the freshly appended tokens); older cache entries were
    //    rotated in previous calls. The `true` flag presumably selects
    //    in-place operation — TODO confirm against RoPE::forward.
    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})->permute({1, 0, 2});                                               // [seq_len, bs * n_q_head, head_dim]
    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})->view({batch_size * num_key_value_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);

    // 6. Compute attention (grouped-query): fold the ngroup query heads that
    //    share one KV head into the row dimension so a single batched matmul
    //    against the un-replicated K/V covers all query heads.
    size_t ngroup = num_attention_heads_ / num_key_value_heads_;
    auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
    auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
    auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});

    auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]

    // Standard 1/sqrt(head_dim) attention scaling, fused into the matmul.
    float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
    auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len]

    // Re-view the same buffer per query head so the causal softmax sees
    // [seq_len, total_seq_len] rows; applied in place (same src and dst).
    auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
    infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);

    auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]

    // Un-fold heads, move seq_len back before the head dim, and flatten heads
    // into the hidden dimension expected by the output projection.
    auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
                           ->permute({0, 2, 1, 3})
                           ->contiguous()
                           ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]

    auto output = o_proj_->forward(attn_output);

    return output;
}

/// Injects the shared rotary-position-embedding module used by forward().
/// Must be called before the first forward(), which throws when it is unset.
void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
    this->rotary_emb_ = rotary_emb;
}

} // namespace infinilm::models::llama