// llama_for_causal_lm.hpp
#pragma once

3
#include "../infinilm_model.hpp"
4
#include "llama_model.hpp"
5
6

#include "infinicore/device.hpp"
7
#include "infinicore/nn/linear.hpp"
8
#include "infinicore/nn/module.hpp"
9
10
#include "infinicore/tensor.hpp"

Your Name's avatar
Your Name committed
11
12
#include "../../engine/distributed/distributed.hpp"

13
14
15
16
17
18
19
20
21
22
namespace infinilm::models::llama {

/**
 * @brief Llama model for Causal Language Modeling
 *
 * Extends LlamaModel by adding a language modeling head (lm_head) that
 * projects hidden states to vocabulary logits.
 *
 * This matches the structure of HuggingFace's LlamaForCausalLM.
 */
23
class LlamaForCausalLM : public InfinilmModel {
24
25
26
27
28
29
30
public:
    /**
     * @brief Construct LlamaForCausalLM module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
     */
31
32
33
34
35
36
37
38
39
40
41
42
    /**
     * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions permitted on this interface
     *   - Only critical bug fixes (security/stability) allowed until removal
     *   - All new code MUST migrate to the polymorphic overload below
     *
     * Replacement: Use the polymorphic overload of this same function name with updated signature
     * Reason: Legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
Your Name's avatar
Your Name committed
43
44
    LlamaForCausalLM(const LlamaConfig &config,
                     const infinicore::Device &device,
45
46
                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
                     backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
47

48
49
    LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                     const infinicore::Device &device,
50
51
                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
                     backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
52

53
54
55
    /**
     * @brief Forward pass: compute language modeling logits
     *
56
57
     * @param input Encapsulated input tensors and other parameters
     * @return Output structure containing the result
58
     */
59
    Output forward(const Input &input) const;
60

PanZezhong's avatar
PanZezhong committed
61
    void reset_cache(const cache::CacheConfig *cache_config) override;
Ceng's avatar
Ceng committed
62

63
64
    const cache::CacheConfig *get_cache_config() const override;

65
66
67
68
69
70
71
72
73
74
    // Module information
    LlamaModel &model() { return *model_; }
    const LlamaModel &model() const { return *model_; }

protected:
    // Base model
    INFINICORE_NN_MODULE(LlamaModel, model);

    // Language modeling head
    INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
75
76

    std::unique_ptr<cache::CacheConfig> cache_config_;
77
78
79
};

} // namespace infinilm::models::llama