#include "rank_worker.hpp" #include "../models/model_factory.hpp" #include "infinicore/ops.hpp" #include #include #include namespace infinilm::engine { RankWorker::RankWorker(const InfinilmModel::Config &model_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config) : model_config_(model_config), rank_info_(rank_info), job_cmd_(Command::INIT), has_job_(false), job_done_(false), should_exit_(false), init_done_(false), rng_(std::random_device{}()) { if (cache_config != nullptr) { pending_cache_config_ = cache_config->unique_copy(); } // start the thread thread_ = std::thread(&RankWorker::thread_loop, this); // Wait until the worker thread finishes initialization (model created) std::unique_lock lk(mutex_); cv_.wait(lk, [&] { return init_done_; }); } std::string RankWorker::info() const { std::stringstream ss; ss << "RankWorker{"; // Rank related ss << rank_info_.to_string() << " "; // Flags ss << "| init_done: " << (init_done_ ? "true" : "false") << " "; ss << "| should_exit: " << (should_exit_ ? "true" : "false") << " "; ss << "| has_job: " << (has_job_ ? "true" : "false") << " "; ss << "| job_done: " << (job_done_ ? "true" : "false") << " "; ss << "}"; return ss.str(); } //------------------------------------------------------ // load_param -- synchronous (blocks until worker finishes loading) //------------------------------------------------------ void RankWorker::load_param(const std::string &name, const infinicore::Tensor ¶m) { { std::lock_guard lock(mutex_); // If the worker is stopping, don't submit new jobs. if (should_exit_) { throw std::runtime_error("RankWorker is closing; cannot load_param"); } pending_param_name_ = name; pending_param_ = param; job_cmd_ = Command::LOAD; has_job_ = true; job_done_ = false; } cv_.notify_all(); // Wait for job completion std::unique_lock lk(mutex_); cv_.wait(lk, [&] { return job_done_ || should_exit_; }); if (should_exit_) { throw std::runtime_error("RankWorker stopped while loading parameter"); } } //------------------------------------------------------ // state_dict -- //------------------------------------------------------ std::unordered_map RankWorker::state_dict() { std::unique_lock lk(mutex_); cv_.wait(lk, [&] { return init_done_ || should_exit_; }); if (!model_) { throw std::runtime_error("state_dict called before model initialization"); } return model_->state_dict(); } //------------------------------------------------------ // run -- asynchronous //------------------------------------------------------ void RankWorker::run(const Input &args) { std::lock_guard lock(mutex_); if (should_exit_) { throw std::runtime_error("RankWorker is closing; cannot run"); } pending_args_ = args; job_cmd_ = Command::RUN; has_job_ = true; job_done_ = false; cv_.notify_all(); } //------------------------------------------------------ // wait -- asynchronous //------------------------------------------------------ void RankWorker::wait() { std::unique_lock lk(mutex_); cv_.wait(lk, [&] { return job_done_ || should_exit_; }); if (should_exit_) { throw std::runtime_error("RankWorker stopped during run"); } } void RankWorker::reset_cache(const cache::CacheConfig *new_config) { std::lock_guard lock(mutex_); if (should_exit_) { throw std::runtime_error("RankWorker is closing; cannot reset_cache"); } // Store both the position and the new config pending_cache_config_ = new_config->unique_copy(); job_cmd_ = Command::RESET_CACHE; has_job_ = true; job_done_ = false; cv_.notify_all(); } //------------------------------------------------------ // close -- request shutdown and join thread //------------------------------------------------------ void RankWorker::close() { { std::lock_guard lock(mutex_); should_exit_ = true; has_job_ = false; // don't keep old jobs pending job_cmd_ = Command::STOP; } cv_.notify_all(); if (thread_.joinable()) { thread_.join(); } } //------------------------------------------------------ // get_output (thread safe) //------------------------------------------------------ RankWorker::Output RankWorker::get_output() { std::lock_guard lock(mutex_); return output_; } //------------------------------------------------------ // thread_loop //------------------------------------------------------ void RankWorker::thread_loop() { try { { std::lock_guard lk(mutex_); // Initialize device & model outside of holding the main mutex to avoid blocking callers. infinicore::context::setDevice(rank_info_.device); // Create model using factory (may be expensive) model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr); if (!model_) { throw std::runtime_error("Failed to create model"); } init_done_ = true; } cv_.notify_all(); // Main loop: wait for jobs or exit while (true) { Command local_cmd = Command::INIT; std::string local_param_name; infinicore::Tensor local_param; Input local_args; std::unique_ptr local_cache_config; // Wait for a job or exit { std::unique_lock lk(mutex_); cv_.wait(lk, [&] { return has_job_ || should_exit_; }); if (should_exit_) { break; } // capture job data and clear has_job_ local_cmd = job_cmd_; if (local_cmd == Command::LOAD) { local_param_name = pending_param_name_; local_param = pending_param_; } else if (local_cmd == Command::RUN) { local_args = pending_args_; } else if (local_cmd == Command::RESET_CACHE) { if (pending_cache_config_ != nullptr) { local_cache_config = pending_cache_config_->unique_copy(); } } // mark job as being processed has_job_ = false; job_done_ = false; } // unlock mutex while executing the job // Execute job outside the lock if (local_cmd == Command::LOAD) { try { model_->load_parameter(local_param_name, local_param); } catch (const std::exception &e) { // convert exceptions to a safe behavior: set should_exit_ and notify caller std::lock_guard lk(mutex_); should_exit_ = true; job_done_ = true; cv_.notify_all(); // rethrow so the thread can be joined and caller sees an error if desired (optional) spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what()); break; } // signal completion { std::lock_guard lk(mutex_); job_done_ = true; } cv_.notify_all(); } else if (local_cmd == Command::RUN) { try { { std::lock_guard lk(mutex_); auto model_args = local_args.to_model_input(rank_info_.device); // Forward calculation auto logits{model_->forward(model_args).logits}; // Random sampling (rank 0 only) if (rank_info_.tp_rank == 0) { auto temperature{local_args.temperature}; auto top_p{local_args.top_p}; auto top_k{local_args.top_k}; const auto &logits_shape{logits->shape()}; const auto &vocab_size{logits_shape[2]}; const auto &total_len{logits_shape[1]}; const auto &batch_size{logits_shape[0]}; auto n_req = local_args.input_offsets.value()->size(0) - 1; int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data(); auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)}; for (auto i{decltype(n_req)(0)}; i < n_req; ++i) { auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})}; auto out{output_ids->narrow({{0, i, 1}})->view({})}; float random_val = std::uniform_real_distribution(0, 1)(rng_); infinicore::op::random_sample_( out, score, random_val, top_p, top_k, temperature); } output_ids = output_ids->to(infinicore::Device::cpu()); infinicore::context::syncStream(); auto out{Output{output_ids}}; output_ = std::move(out); } job_done_ = true; } cv_.notify_all(); } catch (const std::exception &e) { std::lock_guard lk(mutex_); should_exit_ = true; job_done_ = true; cv_.notify_all(); spdlog::error("[{}] exception during forward: {}\n", info(), e.what()); break; } } else if (local_cmd == Command::RESET_CACHE) { try { model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr); { std::lock_guard lk(mutex_); job_done_ = true; } cv_.notify_all(); } catch (const std::exception &e) { std::lock_guard lk(mutex_); should_exit_ = true; job_done_ = true; cv_.notify_all(); spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what()); break; } } else { // Shouldn't reach here (no-op) } } // while } catch (const std::exception &e) { // Top-level exception: ensure any waiters are woken and the thread exits cleanly. { std::lock_guard lk(mutex_); should_exit_ = true; job_done_ = true; } cv_.notify_all(); spdlog::error("[{}] fatal exception in thread_loop: {} \n", info(), e.what()); } } } // namespace infinilm::engine