#include "infer_engine.hpp" #include "spdlog/spdlog.h" namespace infinilm::engine { //------------------------------------------------------ // Constructor //------------------------------------------------------ /** * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0). * * ⚠️ DEVELOPMENT POLICY: * - NO new development or feature additions permitted on this interface * - Only critical bug fixes (security/stability) allowed until removal * - All new code MUST migrate to the polymorphic overload below * * Replacement: Use the polymorphic overload of this same function name with updated signature * Reason: Legacy signature lacks support for dynamic quantization modes. * Removal target: v0.2.0 (Q2 2026) */ InferEngine::InferEngine( const InfinilmModel::Config &config, const distributed::DistConfig &distributed_config, infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend) // Changed parameter : communication_group_(distributed_config, device_type), legacy_model_config_(config), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } // Create one RankWorker per rank int world_size = communication_group_.get_world_size(); barrier_ = std::make_unique((size_t)world_size); workers_.reserve(world_size); for (int r = 0; r < world_size; ++r) { workers_.emplace_back(std::make_unique( legacy_model_config_, communication_group_.get_rank_info(r), cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, attention_backend_)); } // Compile the model on all workers this->compile(); } InferEngine::InferEngine( const std::string &model_path, const distributed::DistConfig &distributed_config, infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = std::make_shared(model_path + "/config.json"); // Create one RankWorker per rank int world_size = communication_group_.get_world_size(); barrier_ = std::make_unique((size_t)world_size); workers_.reserve(world_size); for (int r = 0; r < world_size; ++r) { workers_.emplace_back(std::make_unique( model_config_, communication_group_.get_rank_info(r), cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, attention_backend_)); } // Compile the model on all workers this->compile(); } //------------------------------------------------------ // load_param //------------------------------------------------------ void InferEngine::load_param(const std::string &name, const infinicore::Tensor ¶m) { // Load the parameter on all workers for (auto &worker : workers_) { worker->load_param(name, param); } } //------------------------------------------------------ // state_dict //------------------------------------------------------ std::vector> InferEngine::state_dict() { std::vector> results; if (0 == workers_.size()) { throw std::runtime_error(" Model object not found. 
"); } for (auto &worker : workers_) { results.push_back(worker->state_dict()); } return results; } //------------------------------------------------------ // forward //------------------------------------------------------ infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const { auto to_device = [&](const std::optional &t) -> std::optional { return t.has_value() ? t.value()->to(device) : t; }; return { to_device(input_ids), // @todo: on device in the future to_device(position_ids), to_device(past_sequence_lengths), // @todo: on device in the future to_device(total_sequence_lengths), to_device(input_offsets), to_device(cu_seqlens), to_device(block_tables), to_device(slot_mapping), }; } InferEngine::Output InferEngine::forward(const InferEngine::Input &input) { // Trigger each worker to run inference for (auto &worker : workers_) { worker->run(input); } // Wait for all workers for (auto &worker : workers_) { worker->wait(); } return workers_[0]->get_output(); } void InferEngine::compile() { for (auto &worker : workers_) { worker->compile(); } // Wait for all workers for (auto &worker : workers_) { worker->wait(); } } //------------------------------------------------------ // Destructor //------------------------------------------------------ InferEngine::~InferEngine() { // Close all workers for (auto &worker : workers_) { worker->close(); } } const distributed::DistConfig &InferEngine::get_dist_config() const { return communication_group_.get_dist_config(); } //------------------------------------------------------ // reset_cache (overloaded with CacheConfig) //------------------------------------------------------ void InferEngine::reset_cache(const cache::CacheConfig *new_config) { for (auto &worker : workers_) { worker->reset_cache(new_config); } for (auto &worker : workers_) { worker->wait(); } cache_config_ = new_config->unique_copy(); this->compile(); } } // namespace infinilm::engine