#pragma once #include "distributed/distributed.hpp" #include "infinicore/tensor.hpp" #include "rank_worker.hpp" #include #include namespace infinilm::engine { class InferEngine { public: InferEngine( const std::any &config, const distributed::DistConfig &distributed_config = distributed::DistConfig(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType()); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); // return the parameters (i.e. weights and biases). std::unordered_map state_dict(); // Run a single forward pass on all workers and return the outputs from all ranks infinicore::Tensor generate(const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids); ~InferEngine(); const distributed::DistConfig &get_dist_config() const; protected: std::vector> workers_; distributed::CommunicationGroup communication_group_; std::any model_config_; }; } // namespace infinilm::engine