#pragma once

#include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

namespace infinilm::engine {

class InferEngine {
public:
    struct Input {
        /// Token IDs tensor of shape `[batch, seq_len]`.
        std::optional<infinicore::Tensor> input_ids;
        /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
        std::optional<infinicore::Tensor> position_ids;
        /// Past lengths of the cached sequence for each request, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> cache_lengths;
        /// Input lengths of each request in a continuous-batched sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> input_lengths;
        /// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> input_offsets;
        /// Block IDs for each request, of shape `[batch, max_block_table_length]`. Used for paged cache.
        std::optional<infinicore::Tensor> block_tables;
        /// Slot IDs for each token, of shape `[seq]`. Used for paged cache.
        std::optional<infinicore::Tensor> slot_mapping;

        infinilm::InfinilmModel::Input to_model_input() const;
    };

    struct Output {
        infinicore::Tensor logits;
    };

    // Updated constructor: accepts CacheConfig instead of CacheType.
    InferEngine(
        const InfinilmModel::Config &config,
        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
        const cache::CacheConfig *cache_config = nullptr);

    // Load a parameter to all workers (each can extract its shard inside RankWorker).
    void load_param(const std::string &name, const infinicore::Tensor &param);

    // Return the parameters (i.e. weights and biases).
    std::vector<std::pair<std::string, infinicore::Tensor>> state_dict();

    // Run a single forward pass on all workers and return the outputs from all ranks.
    Output forward(const Input &input);

    void reset_cache(const cache::CacheConfig *new_config);

    ~InferEngine();

    const distributed::DistConfig &get_dist_config() const;

    // Get the current KV-cache configuration.
    const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); }

protected:
    std::vector<std::shared_ptr<RankWorker>> workers_;
    distributed::CommunicationGroup communication_group_;
    const InfinilmModel::Config &model_config_;
    std::unique_ptr<cache::CacheConfig> cache_config_;
};

} // namespace infinilm::engine
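
// Illustrative usage sketch (kept in a comment so the header stays
// declaration-only). This shows the intended call sequence of the API above;
// `checkpoint_params`, `token_ids`, and `positions` are hypothetical
// placeholders, not names defined by this library. Real code would populate
// InfinilmModel::Config from a checkpoint and supply real tensors.
//
//   using namespace infinilm;
//
//   // Single device, default distributed config, default cache.
//   engine::InferEngine engine(model_config);
//
//   // Push every weight to the workers; each rank extracts its own shard.
//   for (const auto &[name, tensor] : checkpoint_params) {
//       engine.load_param(name, tensor);
//   }
//
//   // Continuous-batched prefill: all optional Input fields default to empty,
//   // so only the fields relevant to this run need to be set.
//   engine::InferEngine::Input input;
//   input.input_ids = token_ids;       // [batch, seq_len]
//   input.position_ids = positions;    // [seq_len]
//   engine::InferEngine::Output out = engine.forward(input);
//   // out.logits holds the forward-pass output for sampling.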