// infer_engine.hpp
#pragma once

#include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_barrier.hpp"
#include "rank_worker.hpp"

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class InferEngine {
public:
17
    using Input = RankWorker::Input;
PanZezhong's avatar
PanZezhong committed
18

19
    using Output = RankWorker::Output;
20

21
    // Updated constructor: accept CacheConfig instead of CacheType
22
    InferEngine(
Jiacheng Huang's avatar
Jiacheng Huang committed
23
        const InfinilmModel::Config &config,
24
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
25
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
26
27
        const cache::CacheConfig *cache_config = nullptr,
        bool enable_graph_compiling = false);
28
29
30
31

    // Load a parameter to all workers (each can extract its shard inside RankWorker)
    void load_param(const std::string &name, const infinicore::Tensor &param);

32
    // return the parameters (i.e. weights and biases).
33
    std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
34

35
    // Run a single forward pass on all workers and return the outputs from all ranks
36
    Output forward(const Input &input);
37

38
39
    void compile();

PanZezhong's avatar
PanZezhong committed
40
    void reset_cache(const cache::CacheConfig *new_config);
Ceng's avatar
Ceng committed
41

42
43
44
45
    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

46
    // Get current KV configuration
PanZezhong's avatar
PanZezhong committed
47
    const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); }
48

49
50
protected:
    std::vector<std::unique_ptr<RankWorker>> workers_;
51
    std::unique_ptr<RankBarrier> barrier_;
52
    distributed::CommunicationGroup communication_group_;
Jiacheng Huang's avatar
Jiacheng Huang committed
53
    const InfinilmModel::Config &model_config_;
PanZezhong's avatar
PanZezhong committed
54
    std::unique_ptr<cache::CacheConfig> cache_config_;
55
56
57
};

} // namespace infinilm::engine