#pragma once

#include "../config/model_config.hpp"
#include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_barrier.hpp"
#include "rank_worker.hpp"

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class InferEngine {
public:
    using Input = RankWorker::Input;

    using Output = RankWorker::Output;

    // Updated constructor: accepts a CacheConfig instead of a CacheType.
    /**
     * @deprecated This constructor is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions are permitted on this interface
     *   - Only critical bug fixes (security/stability) are allowed until removal
     *   - All new code MUST migrate to the overload below
     *
     * Replacement: use the overload below that takes a model path (std::string) instead of a legacy config.
     * Reason: the legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
    InferEngine(
        const InfinilmModel::Config &config,
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
        const cache::CacheConfig *cache_config = nullptr,
        bool enable_graph_compiling = false,
        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

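    /**
     * Preferred constructor (see the deprecation note above): builds the engine
     * from a model path rather than a legacy InfinilmModel::Config.
     */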
    InferEngine(
        const std::string &model_path = "",
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
        const cache::CacheConfig *cache_config = nullptr,
        bool enable_graph_compiling = false,
        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
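    /*
     * Usage sketch (illustrative, not compiled here): construct the engine from
     * a model path, leaving all other parameters at their defaults. The path is
     * a placeholder.
     *
     * @code
     *   infinilm::engine::InferEngine engine{"/path/to/model"};
     * @endcode
     */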

    // Load a parameter into all workers (each worker can extract its shard inside RankWorker).
    void load_param(const std::string &name, const infinicore::Tensor &param);
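    /*
     * Usage sketch: stream weights into the engine one tensor at a time. The
     * named_tensors() loader below is hypothetical; any source of
     * name -> infinicore::Tensor pairs works.
     *
     * @code
     *   for (const auto &[name, tensor] : loader.named_tensors()) // hypothetical loader
     *       engine.load_param(name, tensor);
     * @endcode
     */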

    // Return the parameters (i.e. weights and biases).
    std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
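    /*
     * Usage sketch: inspect the loaded parameters; the returned vector
     * presumably holds one map per rank worker.
     *
     * @code
     *   auto dicts = engine.state_dict();
     *   for (std::size_t rank = 0; rank < dicts.size(); ++rank)
     *       for (const auto &[name, param] : dicts[rank])
     *           std::printf("rank %zu: %s\n", rank, name.c_str());
     * @endcode
     */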

    // Run a single forward pass on all workers and return the outputs from all ranks
    Output forward(const Input &input);
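    /*
     * Usage sketch: a single forward pass. How Input is populated is defined by
     * RankWorker::Input in rank_worker.hpp; make_input below is a hypothetical
     * helper standing in for that construction.
     *
     * @code
     *   InferEngine::Input input = make_input(prompt_token_ids); // hypothetical
     *   InferEngine::Output output = engine.forward(input);
     * @endcode
     */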

    // Trigger graph compilation on all workers (see the enable_graph_compiling
    // constructor flag).
    void compile();

    // Reset the KV cache with a new configuration.
    void reset_cache(const cache::CacheConfig *new_config);
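    /*
     * Usage sketch: re-initialize the cache between requests. Passing the
     * current configuration back in (clearing contents while keeping the same
     * layout) is an assumption about the intended semantics.
     *
     * @code
     *   engine.reset_cache(engine.get_cache_config());
     * @endcode
     */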

    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

    // Get the current KV-cache configuration.
    const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); }

protected:
    std::vector<std::unique_ptr<RankWorker>> workers_;
    std::unique_ptr<RankBarrier> barrier_;
    distributed::CommunicationGroup communication_group_;
    std::unique_ptr<cache::CacheConfig> cache_config_;
    // NOTE: every constructor must bind this reference in its mem-initializer list;
    // the default member initializer below binds a temporary, which is ill-formed
    // (dangling) if any constructor actually falls back to it.
    const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
    backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
};

} // namespace infinilm::engine