llama_model.cpp 4.09 KB
Newer Older
1
2
3
4
5
6
7
8
#include "llama_model.hpp"

#include <stdexcept>

#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"

namespace infinilm::models::llama {

Your Name's avatar
Your Name committed
9
10
11
12
/// Builds the Llama model: token embedding, decoder layers, final norm and a
/// rotary embedding module that is handed to every decoder layer.
///
/// @param config    Model hyper-parameters; a copy is kept in `config_`.
/// @param device    Device passed to every sub-module.
/// @param dtype     Data type passed to every sub-module.
/// @param rank_info Distributed rank information forwarded to each decoder layer.
LlamaModel::LlamaModel(const LlamaConfig &config,
                       const infinicore::Device &device,
                       infinicore::DataType dtype,
                       engine::distributed::RankInfo rank_info)
    : config_(config) {
    // Token embedding table: vocab_size x hidden_size.
    INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
                              std::nullopt, dtype, device);

    // Decoder stack. Each layer is registered as "layers.<index>" and receives
    // its own index as a constructor argument.
    // TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments
    //       (e.g., via a factory function or lambda that receives the layer index).
    //       Until then the macro cannot be used here, because every layer needs
    //       a different layer index.
    layers_.reserve(config.num_hidden_layers);
    for (size_t layer_idx = 0; layer_idx < config.num_hidden_layers; ++layer_idx) {
        layers_.push_back(this->register_module<LlamaDecoderLayer>(
            "layers." + std::to_string(layer_idx), config, device, layer_idx, dtype, rank_info));
    }

    // Final RMS normalization over the hidden dimension.
    INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps,
                              dtype, device);

    // One rotary position embedding module, shared by all layers.
    // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing.
    INFINICORE_NN_MODULE_INIT(rotary_emb, config.head_dim, config.max_position_embeddings,
                              config.rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX,
                              dtype, device);

    // Hand the shared rotary embedding to every layer that was created.
    for (auto &decoder_layer : layers_) {
        if (!decoder_layer) {
            continue;
        }
        decoder_layer->set_rotary_emb(rotary_emb_);
    }
}

/// Runs the embedding, every decoder layer and the final normalization, and
/// returns the normalized hidden state of the last position only.
///
/// @param input_ids    Token ids passed to the embedding module.
/// @param position_ids Position ids forwarded unchanged to every decoder layer.
/// @param kv_cache     Opaque cache handle forwarded to every decoder layer.
///                     When nullptr, a model-owned DynamicCache is created on
///                     the first call and reused on later calls.
/// @return Output of the final norm applied to the last element along
///         dimension 1 of the hidden states.
/// @throws std::runtime_error if dimension 1 of the hidden states is empty.
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
                                       const infinicore::Tensor &position_ids,
                                       void *kv_cache) const {
    // Use the persistent internal cache if no external cache is provided.
    // This is intended to match the Python backend: if use_cache is set and
    // past_key_values is None, a DynamicCache is created. Keeping it at model
    // level lets it persist across calls (prefill -> decode -> decode...).
    // NOTE(review): cache_ is assigned inside a const method without any
    // synchronization visible here; confirm that callers do not invoke
    // forward() concurrently on the same model.
    void *cache_to_use = kv_cache;
    if (cache_to_use == nullptr) {
        if (!cache_) {
            // First call without an external cache: create the internal one.
            cache_ = std::make_unique<infinilm::cache::DynamicCache>(
                config_.num_hidden_layers,
                config_.max_position_embeddings);
        }
        cache_to_use = cache_.get();
    }

    // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
    auto hidden_states = embed_tokens_->forward(input_ids);

    // 2. Process through all decoder layers in order. Every layer receives the
    //    same cache handle; the layer index is a property of the layer itself.
    for (const auto &layer : layers_) {
        hidden_states = layer->forward(hidden_states, position_ids, cache_to_use);
    }

    // 3. Apply the final normalization to the last token only.
    //    Narrow: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
    const auto shape = hidden_states->shape();
    const size_t seq_len = shape[1];
    if (seq_len == 0) {
        // Without this check the unsigned `seq_len - 1` below would wrap around
        // and request a slice starting far outside the tensor.
        throw std::runtime_error(
            "LlamaModel::forward: sequence length must be greater than zero");
    }
    auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});

    return norm_->forward(last_token);
}

Ceng's avatar
Ceng committed
94
95
96
97
98
99
/// Resets the model-owned cache to position `pos`.
///
/// Does nothing when no internal cache has been created yet.
///
/// @param pos Position forwarded to the cache's reset().
void LlamaModel::reset_cache(size_t pos) const {
    if (!cache_) {
        return;
    }
    cache_->reset(pos);
}

100
} // namespace infinilm::models::llama