#pragma once #include "../models/infinilm_model.hpp" #include "../models/llama/llama_config.hpp" #include "distributed/distributed.hpp" #include "infinicore/tensor.hpp" #include "rank_worker.hpp" #include #include namespace infinilm::engine { class InferEngine { public: using Input = RankWorker::Input; using Output = RankWorker::Output; // Updated constructor: accept CacheConfig instead of CacheType InferEngine( const InfinilmModel::Config &config, const distributed::DistConfig &distributed_config = distributed::DistConfig(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); // return the parameters (i.e. weights and biases). std::vector> state_dict(); // Run a single forward pass on all workers and return the outputs from all ranks Output forward(const Input &input); void reset_cache(const cache::CacheConfig *new_config); ~InferEngine(); const distributed::DistConfig &get_dist_config() const; // Get current KV configuration const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); } protected: std::vector> workers_; distributed::CommunicationGroup communication_group_; const InfinilmModel::Config &model_config_; std::unique_ptr cache_config_; }; } // namespace infinilm::engine