// llama_for_causal_lm.hpp
#pragma once

#include <any>
#include <cstddef>
#include <vector>

#include "../infinilm_model.hpp"
#include "llama_model.hpp"

#include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"

#include "../../engine/distributed/distributed.hpp"

namespace infinilm::models::llama {

/**
 * @brief Llama model for Causal Language Modeling
 *
 * Extends LlamaModel by adding a language modeling head (lm_head) that
 * projects hidden states to vocabulary logits.
 *
 * This matches the structure of HuggingFace's LlamaForCausalLM.
 */
23
class LlamaForCausalLM : public InfinilmModel {
24
25
26
27
28
29
public:
    /**
     * @brief Construct LlamaForCausalLM module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
30
     * @param dtype Optional data type for model parameters (defaults to BF16)
31
     */
Your Name's avatar
Your Name committed
32
33
34
35
    LlamaForCausalLM(const LlamaConfig &config,
                     const infinicore::Device &device,
                     infinicore::DataType dtype = infinicore::DataType::BF16,
                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
36
37
38
39
40
41

    /**
     * @brief Forward pass: compute language modeling logits
     *
     * @param input_ids Token IDs tensor of shape [batch, seq_len]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
Ceng's avatar
Ceng committed
42
     * @param kv_cache Optional model-level KV cache for incremental decoding
43
44
45
     * @return Logits tensor of shape [batch, seq_len, vocab_size]
     */
    infinicore::Tensor forward(const infinicore::Tensor &input_ids,
46
                               const infinicore::Tensor &position_ids,
Ceng's avatar
Ceng committed
47
                               void *kv_cache = nullptr) const;
48
49

    infinicore::Tensor forward(std::vector<std::any> args) const override;
50

Ceng's avatar
Ceng committed
51
52
    // Reset internal cache position
    void reset_cache(size_t pos = 0) override;
53
    void reset_cache(const cache::CacheConfig &new_config, size_t pos) override;
Ceng's avatar
Ceng committed
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
    // Module information
    const LlamaConfig &config() const { return model_->config(); }
    LlamaModel &model() { return *model_; }
    const LlamaModel &model() const { return *model_; }

protected:
    // Base model
    INFINICORE_NN_MODULE(LlamaModel, model);

    // Language modeling head
    INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
};

} // namespace infinilm::models::llama