// infer_engine.hpp
#pragma once

#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"

#include <any>
#include <memory>
#include <string>
#include <vector>

namespace infinilm::engine {

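// Example usage (a minimal sketch; `ModelConfig`, `weights`, `input_ids`, and
// `position_ids` are placeholders, not names defined by this header):
//
//     ModelConfig cfg = /* ... */;                 // any copyable config type
//     InferEngine engine(std::any{cfg});           // defaults: DistConfig(), current device
//     for (const auto &[name, tensor] : weights)   // e.g. std::map<std::string, infinicore::Tensor>
//         engine.load_param(name, tensor);
//     infinicore::Tensor output = engine.generate(input_ids, position_ids);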
class InferEngine {
public:
    InferEngine(
        const std::any &config,
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType());

    // Load a parameter into all workers (each worker can extract its own shard inside RankWorker)
    void load_param(const std::string &name, const infinicore::Tensor &param);

    // Run a single forward pass on all workers and return the outputs from all ranks
    infinicore::Tensor generate(const infinicore::Tensor &input_ids,
                                const infinicore::Tensor &position_ids);

    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

protected:
    // Per-rank workers; each holds its shard of the model parameters
    std::vector<std::unique_ptr<RankWorker>> workers_;
    // Communication group connecting the rank workers
    distributed::CommunicationGroup communication_group_;
    // Type-erased model configuration passed to the constructor
    std::any model_config_;
};

} // namespace infinilm::engine