#include "llama_for_causal_lm.hpp"

#include "infinicore/context/context.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"

#include <iostream>

namespace infinilm::models::llama {

/// Constructs the causal-LM wrapper: the base decoder stack plus a
/// vocabulary-projection head.
///
/// @param config    Model hyperparameters (hidden size, vocab size, dtype, ...).
/// @param device    Device every submodule is placed on.
/// @param rank_info Distributed-rank information forwarded to the base model.
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
                                   const infinicore::Device &device,
                                   engine::distributed::RankInfo rank_info) {
    // Record the device on the inherited module member.
    device_ = device;

    // Base decoder stack (embeddings + transformer layers).
    INFINICORE_NN_MODULE_INIT(model, config, device, rank_info);

    // Language-modeling head projecting hidden states to vocabulary logits.
    // NOTE(review): when config.tie_word_embeddings is true this should share
    // weights with embed_tokens; an independent linear layer is created for
    // now — confirm against the weight-loading path. The third argument is
    // presumably the bias flag (disabled here) — verify against nn::Linear.
    INFINICORE_NN_MODULE_INIT(lm_head, config.hidden_size, config.vocab_size,
                              false, config.dtype, device);
}
/// Runs a forward pass: decoder stack followed by the LM head.
///
/// @param input Bundle of input tensors; input_ids and position_ids are
///              required optionals (.value() throws if unset), the remaining
///              fields carry KV-cache / paged-attention bookkeeping.
/// @return Output holding the vocabulary logits.
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
    // Required inputs — throws std::bad_optional_access if absent.
    const auto ids = input.input_ids.value();
    const auto positions = input.position_ids.value();

    // Stage 1: decoder stack produces per-token hidden states. The cache /
    // paging arguments are forwarded straight from the input bundle.
    auto hidden = model_->forward(ids,
                                  positions,
                                  input.past_sequence_lengths,
                                  input.total_sequence_lengths,
                                  input.input_offsets,
                                  input.block_tables,
                                  input.slot_mapping);

    // Stage 2: project hidden states onto the vocabulary.
    return {lm_head_->forward(hidden)};
}

/// Resets the KV cache by delegating to the base model, which owns it.
///
/// @param cache_config New cache configuration; semantics of a null pointer
///                     are defined by the base model's reset_cache —
///                     presumably "clear / use defaults", confirm there.
void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
    model_->reset_cache(cache_config);
}

} // namespace infinilm::models::llama