#pragma once #include "../backends/attention_backends.hpp" #include "../cache/cache.hpp" #include "../config/model_config.hpp" #include "../models/model_factory.hpp" #include "compiler/general_compiler.hpp" #include "distributed/distributed.hpp" #include "rank_barrier.hpp" #include #include #include #include #include #include #include namespace infinilm::engine { class RankWorker { enum class Command { INIT, LOAD, RUN, RESET_CACHE, COMPILE, STOP }; public: struct Input { /// Token IDs tensor of shape `[batch, seq_len]`. std::optional input_ids; /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. std::optional position_ids; /// Past Lengths of cached sequence for each request, of shape `[num_requests]`. std::optional past_sequence_lengths; /// ToTal Lengths for each request sequence, of shape `[num_requests]`. std::optional total_sequence_lengths; /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. std::optional input_offsets; /// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`. std::optional cu_seqlens; /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. std::optional block_tables; /// Slot ids for each token `[seq]`. Used for paged cache. std::optional slot_mapping; float temperature{1}; int top_k{50}; float top_p{1}; infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const; }; struct Output { infinicore::Tensor output_ids; }; RankWorker(const InfinilmModel::Config &model_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, backends::AttentionBackend attention_backend); RankWorker(std::shared_ptr model_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, backends::AttentionBackend attention_backend); // Submit a parameter load job and wait until the load completes on the worker thread. void load_param(const std::string &name, const infinicore::Tensor ¶m); // return the parameters (i.e. weights and biases). std::unordered_map state_dict(); // Submit a run (forward) job. void run(const Input &args); // Reset the internal cache with a new configuration void reset_cache(const cache::CacheConfig *new_config); // Compile the model graph if enabled. void compile(); // Wait until run job completes. The result can be retrieved with get_output(). void wait(); // Request worker shutdown and join the thread. void close(); // Thread-safe accessor for last output produced by RUN. Output get_output(); std::string info() const; private: void thread_loop(); private: // Worker properties const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config(); std::shared_ptr model_config_; distributed::RankInfo rank_info_; std::shared_ptr model_; std::shared_ptr cache_; // Backends backends::AttentionBackend attention_backend_; // Graph Compiling bool enable_graph_compiling_; std::unique_ptr compiler_; // Command for the pending job (protected by mutex_) Command job_cmd_; // State flags (protected by mutex_) bool has_job_ = false; // a job is pending bool job_done_ = false; // last job completed bool should_exit_ = false; // request to stop bool init_done_ = false; // initialization finished // Task payloads (protected by mutex) std::string pending_param_name_; infinicore::Tensor pending_param_; Input pending_args_; std::unique_ptr pending_cache_config_; // Output (protected by mutex) Output output_; // Thread sync std::thread thread_; std::mutex mutex_; std::condition_variable cv_; // Random std::mt19937 rng_; RankBarrier *barrier_; }; } // namespace infinilm::engine