#pragma once

#include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp"

#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class InferEngine {
public:
16
    struct Input {
17
18
19
20
21
22
23
24
25
26
27
28
29
30
        /// Token IDs tensor of shape `[batch, seq_len]`.
        std::optional<infinicore::Tensor> input_ids;
        /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
        std::optional<infinicore::Tensor> position_ids;
        /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> cache_lengths;
        /// Input Lengths of each request in a continous-batched sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> input_lengths;
        /// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> input_offsets;
        /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
        std::optional<infinicore::Tensor> block_tables;
        /// Slot ids for each token `[seq]`. Used for paged cache.
        std::optional<infinicore::Tensor> slot_mapping;
PanZezhong's avatar
PanZezhong committed
31
32

        infinilm::InfinilmModel::Input to_model_input() const;
33
34
35
36
37
38
    };

    struct Output {
        infinicore::Tensor logits;
    };

39
    // Updated constructor: accept CacheConfig instead of CacheType
40
    InferEngine(
Jiacheng Huang's avatar
Jiacheng Huang committed
41
        const InfinilmModel::Config &config,
42
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
43
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
PanZezhong's avatar
PanZezhong committed
44
        const cache::CacheConfig *cache_config = nullptr);
45
46
47
48

    // Load a parameter to all workers (each can extract its shard inside RankWorker)
    void load_param(const std::string &name, const infinicore::Tensor &param);

49
    // return the parameters (i.e. weights and biases).
50
    std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
51

52
    // Run a single forward pass on all workers and return the outputs from all ranks
53
    Output forward(const Input &input);
54

PanZezhong's avatar
PanZezhong committed
55
    void reset_cache(const cache::CacheConfig *new_config);
Ceng's avatar
Ceng committed
56

57
58
59
60
    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

61
    // Get current KV configuration
PanZezhong's avatar
PanZezhong committed
62
    const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); }
63

64
65
66
protected:
    std::vector<std::unique_ptr<RankWorker>> workers_;
    distributed::CommunicationGroup communication_group_;
Jiacheng Huang's avatar
Jiacheng Huang committed
67
    const InfinilmModel::Config &model_config_;
PanZezhong's avatar
PanZezhong committed
68
    std::unique_ptr<cache::CacheConfig> cache_config_;
69
70
71
};

} // namespace infinilm::engine