#include "llama_model.hpp"
#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include <iostream>
#include <stdexcept>

namespace infinilm::models::llama {

LlamaModel::LlamaModel(const LlamaConfig &config,
                       const infinicore::Device &device,
                       engine::distributed::RankInfo rank_info)
    : config_(config), rank_info_(rank_info) {
    const auto &dtype{config.dtype};
    // Initialize token embeddings
    INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
                              std::nullopt, dtype, device);

    // Initialize decoder layers with layer indices
    // TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments
    //       (e.g., via a factory function or lambda that receives the layer index)
    //       Currently, we can't use the macro because each layer needs a different layer_idx
    layers_.reserve(config.num_hidden_layers);
    for (size_t i = 0; i < config.num_hidden_layers; ++i) {
        layers_.push_back(this->register_module<LlamaDecoderLayer>(
            "layers." + std::to_string(i), config, device, i, rank_info));
    }
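
    // A hypothetical shape for the factory-style macro suggested in the TODO above
    // (illustration only; no such overload exists in infinicore today):
    //
    //   INFINICORE_NN_MODULE_VEC_INIT(layers, config.num_hidden_layers,
    //                                 [&](size_t layer_idx) {
    //                                     return LlamaDecoderLayer(config, device, layer_idx, rank_info);
    //                                 });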

    // Initialize final layer normalization
    INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps,
                              dtype, device);

    // Initialize Rotary Position Embeddings (shared across all layers)
    // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
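    // (GPT-NeoX pairing rotates (x[i], x[i + head_dim/2]) together; GPT-J-style pairing
    //  would rotate adjacent elements (x[2i], x[2i+1]) instead.)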
    INFINICORE_NN_MODULE_INIT(rotary_emb, config.head_dim, config.max_position_embeddings,
                              config.rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX,
                              dtype, device, config.rope_scaling);

    // Hand the shared rotary embedding module to every decoder layer
    for (auto &layer : layers_) {
        if (layer) {
            layer->set_rotary_emb(rotary_emb_);
        }
    }
}

infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
                                       const infinicore::Tensor &position_ids,
                                       std::optional<infinicore::Tensor> past_sequence_lengths,
                                       std::optional<infinicore::Tensor> total_sequence_lengths,
                                       std::optional<infinicore::Tensor> input_offsets,
                                       std::optional<infinicore::Tensor> block_tables,
                                       std::optional<infinicore::Tensor> slot_mapping) const {
    // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
    auto hidden_states = embed_tokens_->forward(input_ids);

    // 2. Process through all decoder layers
    size_t num_layers = layers_.size();
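    // `residual` is threaded through the decoder layers and consumed by the final norm below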
    infinicore::Tensor residual;
    for (size_t i = 0; i < num_layers; ++i) {
        layers_.at(i)->forward(
            hidden_states,
            residual,
            position_ids,
            kv_cache_,
            past_sequence_lengths,
            total_sequence_lengths,
            input_offsets,
            block_tables,
            slot_mapping);
    }

    // 3. Apply final layer normalization (in place)
    norm_->forward_inplace(hidden_states, residual);

    return hidden_states;
}

void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
    if (cache_config == nullptr) {
        kv_cache_ = nullptr;
        return;
    }
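
    // Select a cache implementation based on the concrete config type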
    if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
        kv_cache_ = std::make_shared<cache::StaticKVCache>(
            config_.head_dim,
            config_.head_dim,
            config_.num_key_value_heads,
            config_.num_key_value_heads,
            config_.num_hidden_layers,
            config_.max_position_embeddings,
            config_.dtype,
            *kv_cache_config,
            rank_info_);

    } else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
        kv_cache_ = std::make_shared<cache::PagedKVCache>(
            config_.head_dim,
            config_.head_dim,
            config_.num_key_value_heads,
            config_.num_key_value_heads,
            config_.num_hidden_layers,
            config_.dtype,
            *paged_kv_cache_config,
            rank_info_);
    } else {
        throw std::runtime_error("Unsupported cache type");
    }
}
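
// Minimal usage sketch (illustration only; `config`, `device`, `rank_info`, `input_ids`
// and `position_ids` are placeholders supplied by the caller, not defined in this file):
//
//   LlamaModel model(config, device, rank_info);
//
//   cache::StaticKVCacheConfig cache_config{/* ... */};
//   model.reset_cache(&cache_config);
//
//   auto hidden = model.forward(input_ids, position_ids,
//                               std::nullopt, std::nullopt,   // past / total sequence lengths
//                               std::nullopt, std::nullopt,   // input offsets / block tables
//                               std::nullopt);                // slot mapping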

} // namespace infinilm::models::llama