#include "llama_for_causal_lm.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/context/context.hpp"
#include <iostream>

namespace infinilm::models::llama {

// Constructs the full causal-LM stack: the base Llama transformer plus a
// linear language-modeling head that projects hidden states to vocab logits.
//
// @param config  Model hyperparameters (hidden_size, vocab_size, ...).
// @param device  Device on which the module's parameters are placed.
// @param dtype   Parameter data type.
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device,
                                   infinicore::DataType dtype) {
    // Initialize module's device_ member first: forward() later uses it to
    // move inputs and to pick the stream to synchronize.
    device_ = device;

    // Initialize base model (the decoder stack without the LM head).
    INFINICORE_NN_MODULE_INIT(model, config, device, dtype);

    // Initialize language modeling head (hidden_size -> vocab_size, no bias).
    // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
    // For now, we create a separate linear layer
    INFINICORE_NN_MODULE_INIT(lm_head, config.hidden_size, config.vocab_size, false,
                              dtype, device);
}

// Runs one full forward pass: base transformer followed by the LM head.
//
// @param input_ids     Token ids to process.
// @param position_ids  Position ids; copied to this module's device before use.
// @param kv_cache      Opaque KV-cache handle passed through to the base model
//                      (may be null).
// @return Logits tensor produced by the LM head.
infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids,
                                             const infinicore::Tensor &position_ids,
                                             void *kv_cache) const {
    // Ensure position ids live on the same device as the model parameters.
    auto positions_on_device = position_ids->to(device_);

    // 1. Base model produces the final hidden states.
    auto hidden = model_->forward(input_ids, positions_on_device, kv_cache);

    // 2. Project hidden states onto the vocabulary.
    auto logits = lm_head_->forward(hidden);

    // 3. CRITICAL: Synchronize the C++ backend's context after forward pass
    // This ensures all C++ backend operations complete before returning to Python
    if (device_.getType() != infinicore::Device::Type::CPU) {
        infinicore::context::setDevice(device_, false);
        infinicore::context::syncStream();
    }

    return logits;
}

// Type-erased entry point used by generic dispatch: unpacks a std::any
// argument list and delegates to the strongly-typed forward().
//
// Expected layout:
//   args[0] - input_ids    (infinicore::Tensor)
//   args[1] - position_ids (infinicore::Tensor)
//   args[2] - optional std::vector<void *>* of per-layer KV caches
//
// @throws std::invalid_argument if fewer than two arguments are supplied.
// @throws std::bad_any_cast    if an argument holds an unexpected type.
infinicore::Tensor LlamaForCausalLM::forward(std::vector<std::any> args) const {
    if (args.size() < 2) {
        throw std::invalid_argument("LlamaForCausalLM::forward requires at least 2 arguments: input_ids and position_ids");
    }

    // Extract input tensors from args
    const auto &input_ids = std::any_cast<const infinicore::Tensor &>(args[0]);
    const auto &position_ids = std::any_cast<const infinicore::Tensor &>(args[1]);

    // Optional KV caches. Guard with has_value() so a caller that passes an
    // empty std::any placeholder in slot 2 gets "no cache" instead of a
    // std::bad_any_cast thrown from the cast below.
    std::vector<void *> *kv_caches = nullptr;
    if (args.size() >= 3 && args[2].has_value()) {
        kv_caches = std::any_cast<std::vector<void *> *>(args[2]);
    }

    return forward(input_ids, position_ids, kv_caches);
}

// Resets the KV cache by delegating to the base model.
//
// @param pos  Position passed through to model_->reset_cache(); exact
//             rewind/clear semantics are defined by the base model.
void LlamaForCausalLM::reset_cache(size_t pos) {
    model_->reset_cache(pos);
}

} // namespace infinilm::models::llama