#include "llama_for_causal_lm.hpp"

#include "infinicore/context/context.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"

namespace infinilm::models::llama {
/**
 * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the polymorphic overload below
 *
 * Replacement: Use the polymorphic overload of this same function name with updated signature
 * Reason: Legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
Your Name's avatar
Your Name committed
18
19
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
                                   const infinicore::Device &device,
20
21
                                   engine::distributed::RankInfo rank_info,
                                   backends::AttentionBackend attention_backend) {
22

Ceng's avatar
Ceng committed
23
    // Initialize module's device_ member
24
    device_ = device;
25
    const auto &dtype{config.dtype};
26
    // Initialize base model
27
    INFINICORE_NN_MODULE_INIT(model, config, device, rank_info, attention_backend);
28
29
30
31
32
33
34
35

    // Initialize language modeling head
    // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
    // For now, we create a separate linear layer
    INFINICORE_NN_MODULE_INIT(lm_head, config.hidden_size, config.vocab_size, false,
                              dtype, device);
}

36
37
LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                   const infinicore::Device &device,
38
39
                                   engine::distributed::RankInfo rank_info,
                                   backends::AttentionBackend attention_backend) {
40
41
42
43
44
45

    // Initialize module's device_ member
    device_ = device;
    const auto &dtype{model_config->get_dtype()};

    // Initialize base model
46
    INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info, attention_backend);
47
48
49
50
51
52
53
54
    // Initialize language modeling head
    // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
    // For now, we create a separate linear layer

    INFINICORE_NN_MODULE_INIT(lm_head, model_config->get<size_t>("hidden_size"), model_config->get<size_t>("vocab_size"), false,
                              dtype, device);
}

55
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
56
57
    auto input_ids = input.input_ids.value();
    auto position_ids = input.position_ids.value();
58
59
    auto past_sequence_lengths = input.past_sequence_lengths;
    auto total_sequence_length = input.total_sequence_lengths;
60
    auto input_offsets = input.input_offsets;
61
    auto cu_seqlens = input.cu_seqlens;
62
63
    auto block_tables = input.block_tables;
    auto slot_mapping = input.slot_mapping;
64

65
    // 1. Forward through base model to get hidden states
66
    auto hidden_states = model_->forward(
67
        input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, cu_seqlens, block_tables, slot_mapping);
68
69
70

    // 2. Apply language modeling head to get logits
    auto logits = lm_head_->forward(hidden_states);
71
    return {logits};
72
73
}

PanZezhong's avatar
PanZezhong committed
74
void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
75
76
77
78
79
80
    cache_config_ = cache_config->unique_copy();
    model_->reset_cache(cache_config_.get());
}

/**
 * @brief Non-owning pointer to the currently stored cache configuration.
 * @return Pointer owned by this object; null if reset_cache was never called.
 */
const cache::CacheConfig *LlamaForCausalLM::get_cache_config() const {
    return cache_config_.get();
}

} // namespace infinilm::models::llama