#pragma once

#include "../../cache/kv_cache.hpp"
#include "../../engine/distributed/distributed.hpp"

#include "llama_config.hpp"
#include "llama_decoder_layer.hpp"

#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"

#include <memory>
#include <vector>

namespace infinilm::models::llama {

/**
 * @brief Main Llama model architecture (without language modeling head)
 *
 * This is the core transformer model consisting of:
 * - Token embeddings (embed_tokens)
 * - Multiple decoder layers (layers)
 * - Final layer normalization (norm)
 * - Rotary Position Embeddings (rotary_emb)
 *
 * This matches the structure of HuggingFace's LlamaModel.
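 *
 * A minimal usage sketch (load_config, default_device, and the input tensors
 * are placeholder assumptions, not part of this API):
 * @code
 * LlamaConfig config = load_config();            // hypothetical config loader
 * infinicore::Device device = default_device();  // hypothetical device helper
 * LlamaModel model(config, device);
 * infinicore::Tensor hidden = model.forward(input_ids, position_ids);
 * // hidden has shape [batch, seq_len, hidden_size]
 * @endcode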
 */
class LlamaModel : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaModel module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
     * @param rank_info Distributed rank information (defaults to a default-constructed RankInfo)
     */
    LlamaModel(const LlamaConfig &config,
               const infinicore::Device &device,
               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    /**
     * @brief Forward pass: process input through the model
     *
     * @param input_ids Token IDs tensor of shape [batch, seq_len]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
     * @param kv_cache Optional model-level KV cache for incremental decoding
     * @return Output tensor of shape [batch, seq_len, hidden_size]
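     *
     * A sketch of one decode step with a persistent cache (the factory below is
     * a placeholder; DynamicCache is assumed to be the concrete type behind the
     * opaque kv_cache pointer, matching the cache members declared below):
     * @code
     * std::shared_ptr<cache::DynamicCache> cache = make_cache(); // hypothetical factory
     * infinicore::Tensor out = model.forward(input_ids, position_ids, cache.get());
     * @endcode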
     */
    infinicore::Tensor forward(const infinicore::Tensor &input_ids,
                               const infinicore::Tensor &position_ids,
                               void *kv_cache = nullptr) const;

    // Module information
    const LlamaConfig &config() const { return config_; }
    size_t num_layers() const { return config_.num_hidden_layers; }

    /**
     * @brief Reset the internal cache to a specific position
     *
     * Call this when starting a new generation sequence so that cache state
     * does not persist across prompts.
     * @param pos Position to reset to (defaults to 0)
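     *
     * @code
     * model.reset_cache(); // begin a fresh sequence at position 0
     * @endcode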
     */
    void reset_cache(size_t pos = 0) const;

    /**
     * @brief Reset the internal cache with a new configuration and position
     *
     * Call this when changing cache parameters (e.g., initial capacity).
     * @param new_config New cache configuration
     * @param pos Position to reset to
     */
    void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) const;

    /**
     * @brief Set external cache for the model
     * @param cache Pointer to external cache (managed by CacheManager)
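     * @note Only a raw pointer is retained; the caller (e.g., the CacheManager)
     *       must keep the shared_ptr alive while the model uses it.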
     */
    void set_external_cache(std::shared_ptr<cache::DynamicCache> cache) {
        external_cache_ = cache.get();
    }

protected:
    // Token embeddings
    INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens);

    // Decoder layers
    INFINICORE_NN_MODULE_VEC(LlamaDecoderLayer, layers);

    // Final normalization
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm);

    // Rotary Position Embeddings (shared across all layers)
    INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb);

private:
    LlamaConfig config_;
    // Persistent cache for when no external cache is provided
    // Mutable because it's not part of the model's learned parameters,
    // but needs to persist across forward calls for incremental decoding
    mutable std::unique_ptr<infinilm::cache::DynamicCache> internal_cache_;
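    // Non-owning pointer to an externally managed cache (see set_external_cache)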
    cache::DynamicCache *external_cache_ = nullptr;
};

} // namespace infinilm::models::llama