infer_engine.hpp 1.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#pragma once

#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"

#include <any>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class InferEngine {
public:
    InferEngine(
        const std::any &config,
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType());

    // Load a parameter to all workers (each can extract its shard inside RankWorker)
    void load_param(const std::string &name, const infinicore::Tensor &param);

22
23
24
    // return the parameters (i.e. weights and biases).
    std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();

25
26
    // Run a single forward pass on all workers and return the outputs from all ranks
    infinicore::Tensor generate(const infinicore::Tensor &input_ids,
27
                                const infinicore::Tensor &position_ids);
28
29
30
31
32
33
34
35
36
37
38
39

    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

protected:
    std::vector<std::unique_ptr<RankWorker>> workers_;
    distributed::CommunicationGroup communication_group_;
    std::any model_config_;
};

} // namespace infinilm::engine