// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    collections::HashMap, ops::Deref, path::Path, process::Stdio, sync::Arc, time::Duration,
    vec::IntoIter,
};

use async_zmq::{SinkExt, StreamExt};
use dynemo_runtime::protocols::annotated::Annotated;
use dynemo_runtime::CancellationToken;
use pyo3::{
    prelude::*,
    types::{IntoPyDict, PyBytes, PyString},
};
use tokio::sync::mpsc::Sender;
use tokio::task::JoinHandle;
use tokio::{io::AsyncBufReadExt, sync::mpsc::error::SendError};

use crate::engines::MultiNodeConfig;
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;

/// If the user does not provide a max_tokens, limit it to this many
const DEFAULT_MAX_TOKENS: u32 = 8192;

/// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);

type RequestID = String;

pub struct VllmWorker {
    /// How we receive work requests. `enqueue_request` pushes onto this; the
    /// input loop holds the receiving end.
    tx: Sender<WorkRequest>,
    /// Handle of the task that reads from `tx` and forwards those requests over zmq to vllm
    _input_loop: JoinHandle<()>,
    /// Handle of the task that reads vllm's responses from zmq and dispatches them to the correct
    /// active request.
    _output_loop: JoinHandle<()>,
    /// Handle of the vllm background process
    vllm: Option<JoinHandle<()>>,
    // We don't need to hold on to this, it's already shared between input_loop and output_loop.
    // But later we'll probably want stats - how many active requests etc - so keep it here.
    _active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
    // Need to keep this alive
    // TODO: With async_zmq we possibly don't need this at all
    #[allow(dead_code)]
    zmq_context: async_zmq::Context,
}

/// How we get asked to do some work. These get unpacked and forwarded to vllm.
pub struct WorkRequest {
    pub request: PreprocessedRequest,
    pub request_id: RequestID,
    pub response_channel: Sender<Annotated<LLMEngineOutput>>,
}

/// A request currently being processed by vllm
struct ActiveRequest {
    tx: Sender<Annotated<LLMEngineOutput>>,
    num_output_tokens_so_far: usize,
    max_tokens: usize,
}

/// Python imports
struct Imports {
    pickle_module: PyObject,
    tokens_prompt_type: PyObject,
    sample_params_type: PyObject,
    rpc_type: PyObject,
    startup_type: PyObject,
}

/// All the zmq sockets we use. This object exists only to pass them around
/// without large tuples.
struct Sockets {
    #[allow(dead_code)]
    context: async_zmq::Context, // we have to keep this alive
    // Control socket, how we ask the vllm engine to start.
    // Not the best name, but this is what vllm calls it internally.
    data: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
    // Requests from us to the vllm engine
    input: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
    // Responses from the vllm engine back to us
    output: async_zmq::Pull,
    // Heartbeat messages from the vllm process
    heartbeat: async_zmq::Pull,
}
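// A hypothetical caller-side sketch (function name, request id, and channel
// capacity are made up): build a WorkRequest, enqueue it, then stream the
// Annotated<LLMEngineOutput> items back until vllm finishes the request.
#[allow(dead_code)]
async fn example_drive_worker(worker: &VllmWorker, request: PreprocessedRequest) {
    let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(16); // capacity is arbitrary
    let work = WorkRequest {
        request,
        request_id: "example-request-1".to_string(),
        response_channel: resp_tx,
    };
    if worker.enqueue_request(work).await.is_err() {
        tracing::error!("Worker input loop is gone, cannot enqueue");
        return;
    }
    // Each item usually carries one new token. The channel closes once the
    // output loop removes the request from `active_requests` and drops our tx.
    while let Some(_output) = resp_rx.recv().await {
        // A real caller would forward `_output` to the client here.
    }
}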
/// The message vllm sends us over zmq when it's ready to work.
#[derive(FromPyObject, Debug)]
struct RPCStartupResponse {
    #[allow(dead_code)]
    tracing_enabled: bool,
}

/// What vllm sends us. Usually it contains a single token.
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct RequestOutput {
    request_id: String,
    prompt: Option<String>,
    prompt_token_ids: Option<Vec<u32>>,
    prompt_logprobs: Option<Vec<Option<HashMap<u32, Logprob>>>>,
    outputs: Vec<CompletionOutput>,
    finished: bool,
    //metrics: Optional[RequestMetrics] = None,
    //lora_request: Optional[LoRARequest] = None,
    encoder_prompt: Option<String>,
    encoder_prompt_token_ids: Option<Vec<u32>>,
    num_cached_tokens: Option<u32>,
}

#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct CompletionOutput {
    index: u32,
    text: String,
    token_ids: Vec<u32>,
    cumulative_logprob: Option<f32>,
    logprobs: Option<Vec<HashMap<u32, Logprob>>>,
    finish_reason: Option<String>,
    //stop_reason: Union[int, str, None] = None
    //lora_request: Optional[LoRARequest] = None
}

#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
struct Logprob {
    logprob: f32,
    rank: Option<u32>,
    decoded_token: Option<String>,
}

/// Main entry point
pub async fn start(
    cancel_token: CancellationToken,
    sock_code: &str,
    card_path: &Path,
    model_path: &Path,
    _node_conf: MultiNodeConfig,
    tensor_parallel_size: u32,
) -> anyhow::Result<VllmWorker> {
    pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
    let py_imports = Arc::new(python_imports());
    let Sockets {
        context,
        data,
        input,
        output,
        heartbeat,
    } = zmq_sockets(sock_code)?;
    let vllm_process = start_vllm(
        card_path,
        model_path,
        &py_imports,
        data,
        tensor_parallel_size,
    )
    .await?;
    let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
    tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));

    let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
    let (tx, rx) = tokio::sync::mpsc::channel(8);
    let input_loop_handle = {
        let cancel_token = cancel_token.clone();
        let py_imports = py_imports.clone();
        let active_requests = active_requests.clone();
        tokio::spawn(input_loop(
            cancel_token,
            py_imports,
            input,
            active_requests,
            rx,
        ))
    };
    let output_loop_handle = {
        let cancel_token = cancel_token.clone();
        let py_imports = py_imports.clone();
        let active_requests = active_requests.clone();
        tokio::spawn(output_loop(
            cancel_token,
            py_imports,
            output,
            active_requests,
        ))
    };
    Ok(VllmWorker {
        tx,
        zmq_context: context,
        _input_loop: input_loop_handle,
        _output_loop: output_loop_handle,
        vllm: Some(vllm_join_handle),
        _active_requests: active_requests,
    })
}

/// Import all the python packages we'll need. `vllm` in particular takes a few seconds.
fn python_imports() -> Imports {
    Python::with_gil(|py| {
        let pickle_module: PyObject = match py.import("pickle") {
            Ok(m) => m.into(),
            Err(err) => {
                // There is no vllm without python
                panic!("Failed to import python 'pickle' module. Is Python installed? {err}");
            }
        };
        let vllm_module: PyObject = match py.import("vllm") {
            Ok(m) => m.into(),
            Err(err) => {
                panic!("Failed to import python 'vllm' module. Are we running in the correct venv? {err}");
            }
        };
        let tokens_prompt_type: PyObject = vllm_module.getattr(py, "TokensPrompt").unwrap();
        let sample_params_type: PyObject = vllm_module.getattr(py, "SamplingParams").unwrap();
        let mod_multiprocessing = py.import("vllm.engine.multiprocessing").unwrap();
        let rpc_type: PyObject = mod_multiprocessing
            .getattr("RPCProcessRequest")
            .unwrap()
            .into();
        let startup_type: PyObject = mod_multiprocessing
            .getattr("RPCStartupRequest")
            .unwrap()
            .into();
        Imports {
            pickle_module,
            tokens_prompt_type,
            sample_params_type,
            rpc_type,
            startup_type,
        }
    })
}
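// A minimal sketch of the pickle round-trip this file performs on every zmq
// message, factored into a standalone helper for illustration. The function
// name and the string payload are made up; `Imports.pickle_module` is the one
// python_imports() returns above.
#[allow(dead_code)]
fn example_pickle_roundtrip(py_imports: &Imports) -> anyhow::Result<String> {
    Python::with_gil(|py| {
        // Serialize: what input_loop() does to an RPCProcessRequest before sending.
        let dumps = py_imports.pickle_module.getattr(py, "dumps")?;
        let bytes: Vec<u8> = dumps.call1(py, ("SUCCESS",))?.extract(py)?;
        // Deserialize: what output_loop() does to each incoming frame.
        let loads = py_imports.pickle_module.getattr(py, "loads")?;
        let value: String = loads.call1(py, (PyBytes::new(py, &bytes),))?.extract(py)?;
        Ok(value)
    })
}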
/// Create all the zmq sockets we're going to use.
fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> {
    let zmq_context = async_zmq::Context::new();
    let input = async_zmq::push(&format!("ipc:///tmp/{sock_code}_input_socket"))?
        .with_context(&zmq_context)
        .connect()?;
    let output = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_output_socket"))?
        .with_context(&zmq_context)
        .connect()?;
    let data = async_zmq::dealer(&format!("ipc:///tmp/{sock_code}_data_socket"))?
        .with_context(&zmq_context)
        .connect()?;
    let heartbeat = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_health_socket"))?
        .with_context(&zmq_context)
        .connect()?;
    Ok(Sockets {
        context: zmq_context,
        data,
        input,
        output,
        heartbeat,
    })
}

/// Start the vllm python sub-process and wait for it to start
async fn start_vllm(
    card_path: &Path,
    model_path: &Path,
    python_imports: &Imports,
    mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
    tensor_parallel_size: u32,
) -> anyhow::Result<tokio::process::Child> {
    let vllm_args = [
        "--internal-vllm-process",
        &format!("--model-config={}", card_path.display()),
        &format!("--model-path={}", model_path.display()),
        &format!("--tensor-parallel-size={tensor_parallel_size}"),
    ];
    let self_path = std::env::current_exe()?;
    let mut proc = tokio::process::Command::new(self_path)
        .env("VLLM_LOGGING_LEVEL", "DEBUG")
        .args(vllm_args)
        .kill_on_drop(false)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;
    let stdout = tokio::io::BufReader::new(proc.stdout.take().unwrap());
    let stderr = tokio::io::BufReader::new(proc.stderr.take().unwrap());
    tokio::spawn(async move {
        let mut lines = stdout.lines();
        while let Ok(Some(line)) = lines.next_line().await {
            let mut line_parts = line.splitn(4, ' ');
            let log_level = line_parts.next().unwrap_or_default();
            // Skip date (0) and time (1). Print last (2) which is everything else.
            let line = line_parts.nth(2).unwrap_or_default();
            if line.starts_with("custom_op.py:68") {
                // Skip a noisy line
                // custom_op.py:68] custom op enabled
                continue;
            }
            match log_level {
                "DEBUG" => tracing::debug!("VLLM: {line}"),
                "INFO" => tracing::info!("VLLM: {line}"),
                "WARNING" => tracing::warn!("VLLM: {line}"),
                level => tracing::info!("VLLM: {level} {line}"),
            }
        }
    });
    tokio::spawn(async move {
        let mut lines = stderr.lines();
        while let Ok(Some(line)) = lines.next_line().await {
            tracing::warn!("VLLM: {line}");
        }
    });

    let start_req_bytes: Vec<u8> = Python::with_gil(|py| {
        let start_req = python_imports
            .startup_type
            .getattr(py, "IS_SERVER_READY")
            .unwrap();
        let pickle_dumps = python_imports.pickle_module.getattr(py, "dumps").unwrap();
        pickle_dumps
            .call1(py, (start_req,))
            .unwrap()
            .extract(py)
            .unwrap()
    });
    data_socket.send(vec![start_req_bytes].into()).await?;
    let start_resp: Vec<u8> = match data_socket.next().await {
        Some(Ok(r)) => {
            if !r.is_empty() {
                r[0].deref().to_vec()
            } else {
                anyhow::bail!("vllm failed to start. No response on dealer/data socket");
            }
        }
        Some(Err(err)) => {
            anyhow::bail!("vllm failed to start. Error reading from dealer/data socket: {err}");
        }
        None => {
            anyhow::bail!("vllm failed to start. dealer/data socket is closed.");
        }
    };
    let resp: RPCStartupResponse = Python::with_gil(|py| {
        let pickle_loads = python_imports.pickle_module.getattr(py, "loads").unwrap();
        pickle_loads
            .call1(py, (start_resp,))
            .unwrap()
            .extract(py)
            .unwrap()
    });
    tracing::info!("vllm zmq backend is ready: {resp:?}");
    Ok(proc)
}

// Stop the vllm process when we stop, and prevent it from going zombie.
fn watch_vllm(
    cancel_token: CancellationToken,
    mut vllm_process: tokio::process::Child,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        cancel_token.cancelled().await;
        tokio::select! {
            _ = vllm_process.wait() => {
                return;
            },
            _ = tokio::time::sleep(VLLM_STOP_TIMEOUT) => { }
        }
        if let Err(err) = vllm_process.start_kill() {
            tracing::error!("Failed to kill vllm subprocess: {err}");
            return;
        }
        tokio::select! {
            _ = vllm_process.wait() => { },
            _ = tokio::time::sleep(VLLM_STOP_TIMEOUT) => {
                tracing::warn!("Timeout waiting for vllm sub-process to stop after kill");
            }
        }
    })
}
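// A hypothetical shutdown sequence (function name is made up) tying the pieces
// above together: cancel the token, then await the watcher so the vllm child is
// reaped and never left as a zombie.
#[allow(dead_code)]
async fn example_shutdown(mut worker: VllmWorker, cancel_token: CancellationToken) {
    cancel_token.cancel();
    // watch_vllm() gives the child VLLM_STOP_TIMEOUT to exit on its own, then
    // sends a KILL and waits once more. take_vllm_handle() is defined at the
    // bottom of this file.
    let _ = worker.take_vllm_handle().await;
}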
// How we know the vllm engine is alive. It sends "SUCCESS" as a pickled string every 10s.
// Runs as its own tokio task, spawned from start().
// TODO: If we don't get heartbeats we should, euh, do something. vllm is gone. At least
// de-register the model.
async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pull) {
    loop {
        let maybe_hb = tokio::select! {
            _ = cancel_token.cancelled() => {
                break;
            }
            maybe_hb = socket.next() => { maybe_hb }
        };
        let b = match maybe_hb {
            Some(Ok(b)) => b[0].deref().to_vec(),
            Some(Err(err)) => {
                tracing::error!("Error reading from vllm heartbeat socket: {err}");
                break;
            }
            None => {
                tracing::debug!("vllm heartbeat socket closed");
                break;
            }
        };
        let s: String = match serde_pickle::from_slice(&b, Default::default()) {
            Ok(s) => s,
            Err(err) => {
                tracing::error!("Error de-serializing vllm heartbeat response. It was probably an Exception, not a str: {err}");
                break;
            }
        };
        if s != "SUCCESS" {
            tracing::error!("vllm heartbeat error, expected 'SUCCESS' got '{s}'");
            break;
        }
    }
}

fn from_vllm(output: CompletionOutput, previous_total_toks: usize) -> LLMEngineOutput {
    let finish_reason = match output.finish_reason.as_deref() {
        Some("stop") => Some(FinishReason::Stop),
        Some("abort") => Some(FinishReason::Cancelled),
        Some("length") => Some(FinishReason::Length),
        Some(unknown) => {
            tracing::info!("Unknown vllm stop reason '{unknown}'. Please add to vllm.rs");
            Some(FinishReason::Stop)
        }
        None => None,
    };
    LLMEngineOutput {
        // todo - propagate mdcsum
        token_ids: output.token_ids[previous_total_toks..].into(),
        tokens: None,
        text: None,
        //text: if output.text.is_empty() { None } else { Some(output.text) },
        cum_log_probs: output.cumulative_logprob.map(|v| v as f64),
        log_probs: None, // TODO output.logprobs
        finish_reason,
    }
}
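// Illustration of the incremental-output contract from_vllm() relies on: vllm
// resends the *cumulative* token_ids on every step, so we slice off the prefix
// we have already forwarded. Function name and token values are made up.
#[allow(dead_code)]
fn example_token_delta() {
    let step2 = CompletionOutput {
        index: 0,
        text: String::new(),
        token_ids: vec![11, 22, 33], // cumulative: 11 and 22 were sent earlier
        cumulative_logprob: None,
        logprobs: None,
        finish_reason: None,
    };
    // Two tokens were already forwarded, so only 33 is new.
    let out = from_vllm(step2, 2);
    assert_eq!(out.token_ids, vec![33]);
}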
async fn input_loop(
    cancel_token: CancellationToken,
    py_imports: Arc<Imports>,
    mut input_socket: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
    active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
    mut rx: tokio::sync::mpsc::Receiver<WorkRequest>,
) {
    loop {
        let work_request = tokio::select! {
            _ = cancel_token.cancelled() => {
                tracing::trace!("VllmWorker.input_loop exit");
                break;
            }
            req = rx.recv() => {
                match req {
                    Some(req) => req,
                    None => {
                        tracing::trace!("VllmWorker input_loop socket closed");
                        break;
                    }
                }
            }
        };
        let request_id = work_request.request_id;
        let token_ids = work_request.request.token_ids.clone();
        let temperature: f64 = work_request
            .request
            .sampling_options
            .temperature
            .unwrap_or(0.0)
            .into();
        let max_tokens = work_request
            .request
            .stop_conditions
            .max_tokens
            .unwrap_or(DEFAULT_MAX_TOKENS);
        // Parts that don't change
        let (py_request_id, sampling_params) = Python::with_gil(|py| {
            let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
            let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
            let sp_kwargs = [("temperature", py_temp), ("max_tokens", py_max_tokens)]
                .into_py_dict(py)
                .unwrap();
            let sampling_params = py_imports
                .sample_params_type
                .call(py, (), Some(&sp_kwargs))
                .unwrap();
            let py_request_id: PyObject = PyString::new(py, &request_id).into();
            (py_request_id, sampling_params)
        });
        let pickled_req: Vec<u8> = Python::with_gil(|py| {
            let token_prompt_kwargs = [("prompt_token_ids", token_ids.clone())]
                .into_py_dict(py)
                .unwrap();
            let prompt_obj = py_imports
                .tokens_prompt_type
                .call(py, (), Some(&token_prompt_kwargs))
                .unwrap();
            let rpc_kwargs = [
                ("prompt", prompt_obj),
                ("params", sampling_params.clone()),
                ("request_id", py_request_id.clone()),
            ]
            .into_py_dict(py)
            .unwrap();
            let req = py_imports.rpc_type.call(py, (), Some(&rpc_kwargs)).unwrap();
            let pickle_dumps = py_imports.pickle_module.getattr(py, "dumps").unwrap();
            pickle_dumps.call1(py, (req,)).unwrap().extract(py).unwrap()
        });
        let new_active_request = ActiveRequest {
            tx: work_request.response_channel,
            max_tokens: max_tokens as usize,
            num_output_tokens_so_far: 0,
        };
        active_requests
            .lock()
            .await
            .insert(request_id, new_active_request);
        if let Err(err) = input_socket.send(vec![pickled_req].into()).await {
            tracing::error!("Error sending new request to vllm over zmq: {err}");
        }
    }
}
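// A minimal sketch isolating the pyo3 kwargs idiom input_loop() uses to build
// vllm.SamplingParams: a slice of (&str, PyObject) pairs turned into a Python
// dict and passed as keyword arguments. Function name and values are made up.
#[allow(dead_code)]
fn example_sampling_params(py_imports: &Imports) -> PyObject {
    Python::with_gil(|py| {
        let temperature: PyObject = 0.7f64.into_pyobject(py).unwrap().into();
        let max_tokens: PyObject = 64u32.into_pyobject(py).unwrap().into();
        let kwargs = [("temperature", temperature), ("max_tokens", max_tokens)]
            .into_py_dict(py)
            .unwrap();
        // Equivalent to Python: vllm.SamplingParams(temperature=0.7, max_tokens=64)
        py_imports
            .sample_params_type
            .call(py, (), Some(&kwargs))
            .unwrap()
    })
}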
/// Read from vllm's output zmq socket, find which request it is for and forward over that channel.
async fn output_loop(
    cancel_token: CancellationToken,
    py_imports: Arc<Imports>,
    mut output_socket: async_zmq::Pull,
    active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
) {
    loop {
        let mut bb = tokio::select! {
            _ = cancel_token.cancelled() => {
                tracing::trace!("VllmWorker.output_loop exit");
                break;
            }
            from_vllm = output_socket.next() => {
                match from_vllm {
                    Some(Ok(b)) => b,
                    Some(Err(err)) => {
                        tracing::error!("Error reading from vllm zmq output: {err}");
                        continue; // hope lives eternal
                    }
                    None => {
                        tracing::debug!("zmq output socket closed");
                        break;
                    }
                }
            }
        };
        let frame = bb.remove(0);
        let mut reqs_out: Vec<RequestOutput> = Python::with_gil(|py| {
            let pickle_loads = py_imports.pickle_module.getattr(py, "loads").unwrap();
            let frame_bytes = PyBytes::new(py, &frame);
            pickle_loads
                .call1(py, (frame_bytes,))
                .unwrap()
                .extract(py)
                .unwrap()
        });
        if reqs_out.is_empty() {
            tracing::debug!("Received message from vllm with no content");
            continue;
        }
        let req_out = reqs_out.remove(0);
        if req_out.finished {
            // The last token is the eos_token, don't forward it
            let out = Annotated::from_data(LLMEngineOutput::stop());
            let maybe_active = active_requests.lock().await.remove(&req_out.request_id);
            match maybe_active {
                Some(active) => {
                    let _ = active.tx.send(out).await;
                }
                None => {
                    tracing::warn!(
                        req_out.request_id,
                        "Missing active request to notify of stop"
                    );
                }
            }
            continue;
        }
        let mut remove_after = false;
        for vllm_output in req_out.outputs.into_iter() {
            let next_total_toks = vllm_output.token_ids.len();
            match active_requests.lock().await.get_mut(&req_out.request_id) {
                Some(active) => {
                    let out = from_vllm(vllm_output, active.num_output_tokens_so_far);
                    active.num_output_tokens_so_far = next_total_toks;
                    let out = if active.num_output_tokens_so_far <= active.max_tokens {
                        Annotated::from_data(out)
                    } else {
                        // we exceeded max tokens, this request is over
                        remove_after = true;
                        Annotated::from_data(LLMEngineOutput::length())
                    };
                    let _ = active.tx.send(out).await;
                }
                None => {
                    tracing::warn!(req_out.request_id, "Missing active request");
                }
            }
        }
        if remove_after {
            let _ = active_requests.lock().await.remove(&req_out.request_id);
        }
    }
}

impl VllmWorker {
    /// Send a request to vllm
    pub async fn enqueue_request(&self, r: WorkRequest) -> Result<(), SendError<WorkRequest>> {
        self.tx.send(r).await
    }

    /// Get the vllm sub-process watcher handle, so we can await it and prevent
    /// the child from going zombie.
    pub fn take_vllm_handle(&mut self) -> JoinHandle<()> {
        self.vllm.take().unwrap()
    }
}
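// A hedged unit-test sketch for from_vllm()'s finish_reason mapping. The test
// names and token values are made up; matches! is used so FinishReason needs no
// PartialEq impl.
#[cfg(test)]
mod tests {
    use super::*;

    fn completion(finish_reason: Option<&str>) -> CompletionOutput {
        CompletionOutput {
            index: 0,
            text: String::new(),
            token_ids: vec![7],
            cumulative_logprob: None,
            logprobs: None,
            finish_reason: finish_reason.map(str::to_string),
        }
    }

    #[test]
    fn maps_vllm_finish_reasons() {
        assert!(matches!(
            from_vllm(completion(Some("stop")), 0).finish_reason,
            Some(FinishReason::Stop)
        ));
        assert!(matches!(
            from_vllm(completion(Some("abort")), 0).finish_reason,
            Some(FinishReason::Cancelled)
        ));
        assert!(matches!(
            from_vllm(completion(Some("length")), 0).finish_reason,
            Some(FinishReason::Length)
        ));
        // No finish_reason means the request is still generating.
        assert!(matches!(from_vllm(completion(None), 0).finish_reason, None));
    }
}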