// infer_engine.hpp — public interface of the InferEngine inference front-end.
#pragma once

#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"

#include <any>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class InferEngine {
public:
14
    // Updated constructor: accept CacheConfig instead of CacheType
15
16
17
    InferEngine(
        const std::any &config,
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
18
19
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
        const cache::CacheConfig &cache_config = cache::CacheConfig());
20
21
22
23

    // Load a parameter to all workers (each can extract its shard inside RankWorker)
    void load_param(const std::string &name, const infinicore::Tensor &param);

24
    // return the parameters (i.e. weights and biases).
25
    std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
26

27
28
    // Run a single forward pass on all workers and return the outputs from all ranks
    infinicore::Tensor generate(const infinicore::Tensor &input_ids,
29
                                const infinicore::Tensor &position_ids);
30

31
32
33
34
35
    // Reset the internal cache pos in all workers (clears state between generations)
    void reset_cache(size_t pos = 0);

    // Overload: reset cache with new KV configuration
    void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
Ceng's avatar
Ceng committed
36

37
38
39
40
    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

41
42
43
    // Get current KV configuration
    const cache::CacheConfig &get_cache_config() const { return cache_config_; }

44
45
46
47
protected:
    std::vector<std::unique_ptr<RankWorker>> workers_;
    distributed::CommunicationGroup communication_group_;
    std::any model_config_;
48
    cache::CacheConfig cache_config_;
49
50
51
};

} // namespace infinilm::engine