Unverified Commit 42969800 authored by Graham King, committed by GitHub

chore: Remove embedded Python vllm and sglang engines (#966)

vllm and sglang are now the sub-process engines from #954

Also updated the docs on running vllm and sglang multi-GPU (tensor parallel) and multi-node (pipeline parallel).
parent 5d89a0c8
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use dynamo_llm::backend::ExecutionContext;
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::CancellationToken;
use pyo3::prelude::*;
mod worker;
mod engine;
use engine::SgLangEngine;
mod subprocess;
pub use subprocess::run_subprocess;
pub async fn make_engine(
cancel_token: CancellationToken,
// Full path to the model directory
model_path: &Path,
// Unique string to name zmq sockets
sock_code: &str,
// Multi node settings
node_conf: dynamo_llm::engines::MultiNodeConfig,
// How many GPUs to use
tensor_parallel_size: u32,
// The base GPU ID to start allocating GPUs from
base_gpu_id: u32,
// Extra arguments to pass directly as sglang ServerArgs
extra_engine_args: Option<PathBuf>,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
let mut engine = SgLangEngine::new(
cancel_token,
sock_code,
model_path,
node_conf,
tensor_parallel_size,
base_gpu_id,
extra_engine_args,
)
.await?;
let sglang_process = engine.take_sglang_worker_handle();
let engine: ExecutionContext = Arc::new(engine);
Ok((engine, sglang_process))
}
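// A minimal caller sketch, for illustration only. `cancel_token` and
// `model_dir` are assumed to come from the surrounding runtime; they are not
// defined in this crate:
//
// let (engine, sglang_handle) = make_engine(
//     cancel_token,
//     &model_dir,
//     "sock1234",
//     Default::default(), // single node
//     1,                  // one GPU
//     0,                  // start allocating at GPU 0
//     None,               // no extra engine args
// ).await?;
// // serve requests through `engine`, then await `sglang_handle` on shutdown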
#[derive(Debug, Clone, Copy)]
pub struct MultiGPUConfig {
/// How many GPUs we are using / how many processes
pub tp_size: u32,
/// Tensor Parallel Rank. Must be unique across all nodes and GPUs.
pub tp_rank: u32,
/// GPU ID. Which GPU to run on. In single-node setup this is the same as tp_rank.
pub gpu_id: u32,
}
impl Default for MultiGPUConfig {
fn default() -> Self {
MultiGPUConfig {
tp_size: 1,
tp_rank: 0,
gpu_id: 0,
}
}
}
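// Example: a single node with two GPUs runs two worker processes, configured as
// MultiGPUConfig { tp_size: 2, tp_rank: 0, gpu_id: 0 } and
// MultiGPUConfig { tp_size: 2, tp_rank: 1, gpu_id: 1 }.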
#[cfg(target_os = "macos")]
fn fix_venv(venv: String, py: Python<'_>) -> anyhow::Result<()> {
let version_info = py.version_info();
let sys: PyObject = py.import("sys")?.into();
let sys_path = sys.getattr(py, "path")?;
let venv_path = format!(
"{venv}/lib/python{}.{}/site-packages",
version_info.major, version_info.minor
);
// TODO: This should go _before_ the site-packages
sys_path.call_method1(py, "append", (venv_path,))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
fn fix_venv(_venv: String, _py: Python<'_>) -> anyhow::Result<()> {
Ok(())
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# This file is included as a string in subprocess.rs. Most work should be done in the Rust caller.
#
import json
import logging
import tempfile
from multiprocessing.connection import Connection
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
logging.basicConfig(
level="DEBUG",
force=True,
datefmt="%Y-%m-%d %H:%M:%S",
format="[%(asctime)s] %(message)s",
)
# These can all be overridden by the --extra-engine-args JSON file
arg_map = {
"model_path": f"{model_path}",
"enable_metrics": False,
"log_level": "debug",
"log_requests": True,
"tp_size": int(tp_size_str),
# Multi-node
"dist_init_addr": dist_init_addr if dist_init_addr != "" else None,
"nnodes": int(nnodes_str),
"node_rank": int(node_rank_str),
}
json_map = {}
if extra_engine_args != "":
# extra_engine_args is a filename
try:
with open(extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.debug(f"File {extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.debug(f"Invalid JSON in {extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
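# For illustration, a hypothetical --extra-engine-args file overriding two of
# the defaults above could contain:
#   {"log_requests": false, "log_level": "info"}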
server_args = ServerArgs(**arg_map)
_set_envs_and_config(server_args)
logging.debug(server_args)
ipc_path = f"ipc:///tmp/{socket_id}"
# These must match worker.rs zmq_sockets, which is the other side
port_args = PortArgs(
# we don't use this one so use anything
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
# Us -> sglang
scheduler_input_ipc_name=f"{ipc_path}_input_socket",
# sglang -> us
detokenizer_ipc_name=f"{ipc_path}_output_socket",
# The port for nccl initialization (torch.dist), which we don't use
nccl_port=9876,
)
# Rank must be globally unique across nodes
tp_rank = int(tp_rank_str)
# See nvidia-smi for GPU IDs; they run 0, 1, 2, etc.
# In a single-node setup this is the same as tp_rank
gpu_id = int(gpu_id_str)
pipe_fd_int = int(pipe_fd)
writer = Connection(handle=pipe_fd_int, readable=False, writable=True)
run_scheduler_process(server_args, port_args, gpu_id, tp_rank, None, writer)
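# Readiness handshake: once the scheduler is up, run_scheduler_process writes a
# pickled status dict (containing "status": "ready") to `writer`; the Rust side
# (worker.rs wait_for_sglang) scans that pipe for the literal bytes b"ready"
# before accepting work.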
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use pyo3::{types::IntoPyDict, Python};
use std::{
env,
ffi::CString,
os::fd::RawFd,
path::{Path, PathBuf},
};
use dynamo_llm::engines::MultiNodeConfig;
const PY_START_ENGINE: &str = include_str!("sglang_inc.py");
/// Start the Python sglang engine that listens on a zmq socket.
/// This is called by running `nio --internal-sglang-process`.
/// This does not return until the subprocess exits.
pub fn run_subprocess(
// The prefix to put on the zmq socket names
socket_id: &str,
// Directory containing an HF repo with safetensors files, tokenizer, etc
model_path: &Path,
// The write half of a pipe, where sglang will signal when it's ready
notify_pipe_fd: RawFd,
// Multi node. Usually Default::default
node_config: MultiNodeConfig,
// Multi GPU. Usually Default::default
gpu_config: super::MultiGPUConfig,
// Allow passing any arguments to sglang
extra_engine_args: Option<PathBuf>,
) -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") {
let _ = Python::with_gil(|py| crate::fix_venv(venv, py));
}
let dir = model_path.display().to_string();
let extra_engine_args_str = &extra_engine_args
.map(|p| p.display().to_string())
.unwrap_or_default();
Python::with_gil(|py| {
let locals = [
("socket_id", socket_id),
("model_path", dir.as_str()),
("pipe_fd", &notify_pipe_fd.to_string()),
// to_string because all slice elements must be the same type
("tp_size_str", &gpu_config.tp_size.to_string()),
("tp_rank_str", &gpu_config.tp_rank.to_string()),
("gpu_id_str", &gpu_config.gpu_id.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()),
("node_rank_str", &node_config.node_rank.to_string()),
("dist_init_addr", &node_config.leader_addr),
("extra_engine_args", extra_engine_args_str),
]
.into_py_dict(py)
.unwrap();
if let Err(err) = py.run(CString::new(PY_START_ENGINE)?.as_ref(), None, Some(&locals)) {
anyhow::bail!("sglang engine run error: {err}");
}
tracing::info!("sglang subprocess exit");
Ok(())
})
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap,
env, fmt,
os::fd::{FromRawFd as _, RawFd},
path::{Path, PathBuf},
process::Stdio,
sync::Arc,
time::Duration,
vec::IntoIter,
};
use anyhow::Context as _;
use async_zmq::{SinkExt, StreamExt};
use libc::c_int;
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{IntoPyDict, PyBytes, PyString},
};
use regex::Regex;
use tokio::sync::mpsc::Sender;
use tokio::{io::AsyncBufReadExt, sync::mpsc::error::SendError};
use tokio::{io::AsyncReadExt as _, task::JoinHandle};
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::runtime::CancellationToken;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
use dynamo_llm::protocols::common::preprocessor::PreprocessedRequest;
use dynamo_llm::protocols::common::FinishReason;
use dynamo_llm::protocols::TokenIdType;
use crate::MultiGPUConfig;
/// Wait this long for the sglang sub-process to stop after we send it a KILL
const SGLANG_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
/// Match sglang python log entries, e.g "[2025-01-30 11:23:16] Some text we want"
const SGLANG_LOG_RE: &str =
r"(?<timestamp>\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\] )?(?<message>.*)";
/// Identify sglang log entries with this prefix
const LOG_PREFIX: &str = "SGLANG";
/// Part of what sglang sends us over its pipe when it's ready
const READY_BYTES: [u8; 5] = [b'r', b'e', b'a', b'd', b'y'];
type RequestID = String;
pub struct SgLangWorker {
/// How we receive work requests
tx: Sender<WorkRequest>,
/// Handle of the task that reads from `tx` and forwards those requests over zmq to sglang
_input_loop: JoinHandle<()>,
/// Handle of the task that reads sglang's responses from zmq and dispatches them to the correct
/// active request.
_output_loop: JoinHandle<()>,
/// Handle of the sglang background process
sglang: Option<JoinHandle<()>>,
// We don't need to hold on to this, it's already shared between input_loop and output_loop
// But later we'll probably want stats - how many active requests etc, so keep it here
_active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
// Need to keep this alive
// TODO: With async_zmq we possibly don't need this at all
#[allow(dead_code)]
zmq_context: async_zmq::Context,
}
/// How we get asked to do some work. These get unpacked and forwarded to sglang.
pub struct WorkRequest {
pub request: PreprocessedRequest,
pub request_id: RequestID,
pub response_channel: Sender<Annotated<LLMEngineOutput>>,
}
/// A request currently being processed by sglang
struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: Option<i32>,
}
/// Python imports
struct Imports {
pickle_module: PyObject,
sampling_params_type: PyObject,
rpc_type: PyObject,
}
/// All the zmq sockets we use. This object exists only to pass them around
/// and avoid large tuples.
struct Sockets {
#[allow(dead_code)]
context: async_zmq::Context, // we have to keep this alive
// Requests from us to the sglang engine
// scheduler_input_ipc_name,
input: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
// Responses from the sglang engine back to us
// tokenizer_ipc_name
output: async_zmq::Pull,
}
/// What sglang sends us.
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct BatchTokenIDOut {
// The request id
rids: Vec<String>,
// The finish reason
// sglang implements finish reason as subclasses of BaseFinishReason
// e.g. `class FINISH_LENGTH(BaseFinishReason):` and lots of others
finished_reasons: Vec<Option<SgLangFinishReason>>,
// For incremental decoding
// The version id to sync decode status with in detokenizer_manager
vids: Vec<i32>,
decoded_texts: Vec<String>,
decode_ids: Vec<Vec<u32>>,
read_offsets: Vec<i32>,
// Only used when `--skip-tokenizer-init` is on
output_ids: Option<Vec<i32>>,
// Detokenization configs
skip_special_tokens: Vec<bool>,
spaces_between_special_tokens: Vec<bool>,
no_stop_trim: Vec<bool>,
// Token counts
prompt_tokens: Vec<i32>,
completion_tokens: Vec<i32>,
cached_tokens: Vec<i32>,
spec_verify_ct: Vec<i32>,
// Logprobs
input_token_logprobs_val: Option<Vec<f64>>,
input_token_logprobs_idx: Option<Vec<i32>>,
output_token_logprobs_val: Option<Vec<f64>>,
output_token_logprobs_idx: Option<Vec<i32>>,
// These in Python are all `List[List]`, so guess
input_top_logprobs_val: Option<Vec<Vec<f64>>>,
input_top_logprobs_idx: Option<Vec<Vec<i32>>>,
output_top_logprobs_val: Option<Vec<Vec<f64>>>,
output_top_logprobs_idx: Option<Vec<Vec<i32>>>,
}
#[derive(Debug, Copy, Clone)]
enum SgLangFinishReason {
Matched,
Length,
Abort,
}
impl fmt::Display for SgLangFinishReason {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
SgLangFinishReason::Matched => write!(f, "Finished due to a successful match"),
SgLangFinishReason::Length => {
write!(f, "Finished due to reaching the specified length")
}
SgLangFinishReason::Abort => write!(f, "Operation was aborted"),
}
}
}
impl<'py> FromPyObject<'py> for SgLangFinishReason {
fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
// The object we have is a subclass of sglang's BaseFinishReason, one subclass
// per finish reason. I don't know how to identify the class, but if we force
// it to a string I _think_ it ends up calling `json_str` in the subclass.
// Also the string uses single quotes in the JSON, I don't know why.
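// For example, str() of a length-finish object is assumed to look like
// {'type': 'length', 'length': 128}, which the quote replacement below
// turns into valid JSON.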
let json_str = obj.str()?.to_string().replace("'", "\"");
let as_map: HashMap<String, serde_json::Value> =
serde_json::from_str(&json_str).map_err(|err| {
tracing::error!("SgLangFinishReason JSON convert err: {err}. JSON: {json_str}");
PyTypeError::new_err(format!("serde_json err: {err}. JSON: {json_str}"))
})?;
let Some(type_serde) = as_map.get("type") else {
return Err(PyTypeError::new_err("Finish reason missing 'type' JSON field. See sglang's schedule_batch.py BaseFinishReason"));
};
let Some(type_str) = type_serde.as_str() else {
return Err(PyTypeError::new_err("Finish reason 'type' JSON field is not a string. See sglang's schedule_batch.py BaseFinishReason"));
};
match type_str {
"stop" => Ok(SgLangFinishReason::Matched),
"length" => Ok(SgLangFinishReason::Length),
"abort" => Ok(SgLangFinishReason::Abort),
x => {
tracing::warn!("Unknown sglang BaseFinishReason type '{x}'. Using Abort instead.");
Ok(SgLangFinishReason::Abort)
}
}
}
}
impl From<SgLangFinishReason> for FinishReason {
fn from(sfr: SgLangFinishReason) -> Self {
use SgLangFinishReason::*;
match sfr {
Matched => FinishReason::Stop,
Length => FinishReason::Length,
Abort => FinishReason::Cancelled, // or FinishReason::Error ?
}
}
}
/* What we send to sglang
class TokenizedGenerateReqInput:
# The request id
rid: str
# The input text
input_text: str
# The input token ids
input_ids: List[int]
# The image inputs
image_inputs: dict
# The sampling parameters
sampling_params: SamplingParams
# Whether to return the logprobs
return_logprob: bool
# If return logprobs, the start location in the prompt for returning logprobs.
logprob_start_len: int
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: int
# Whether to stream output
stream: bool
# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Session info for continual prompting
session_params: Optional[SessionParams] = None
# Custom logit processor for advanced sampling control. Must be a serialized instance
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
# Use the processor's `to_str()` method to generate the serialized string.
custom_logit_processor: Optional[str] = None
class SamplingParams:
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
min_new_tokens: int = 0,
spaces_between_special_tokens: bool = True,
n: int = 1,
json_schema: Optional[str] = None,
regex: Optional[str] = None,
ebnf: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
custom_params: Optional[Dict[str, Any]] = None,
*/
/// Main entry point
pub async fn start(
cancel_token: CancellationToken,
sock_code: &str,
model_path: &Path,
node_conf: MultiNodeConfig,
tp_size: u32,
base_gpu_id: u32,
extra_engine_args: Option<PathBuf>,
) -> anyhow::Result<SgLangWorker> {
pyo3::prepare_freethreaded_python();
if let Ok(venv) = env::var("VIRTUAL_ENV") {
let _ = Python::with_gil(|py| crate::fix_venv(venv, py));
}
let Sockets {
context,
input,
output,
} = zmq_sockets(sock_code)?;
let py_imports = Arc::new(python_imports());
if tp_size < node_conf.num_nodes {
anyhow::bail!(
"Need at least as many GPUs as nodes. Pass --tensor-parallel-size >= --num-nodes."
);
}
let tp_size_per_node = tp_size / node_conf.num_nodes;
let tp_rank_start = tp_size_per_node * node_conf.node_rank;
let tp_rank_end = tp_size_per_node * (node_conf.node_rank + 1);
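// Worked example: tp_size=8, num_nodes=2, base_gpu_id=0 gives
// tp_size_per_node=4, so node 0 starts tp_ranks 0..4 and node 1 starts
// tp_ranks 4..8, each on local gpu_ids 0..4.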
// Start all the sglang workers. They communicate amongst themselves using torch distributed
// and nccl. They must all start at once.
let mut sglang_join_handle = None;
let mut process_group = Vec::with_capacity(tp_size as usize);
for tp_rank in tp_rank_start..tp_rank_end {
let gpu_id = base_gpu_id + tp_rank % tp_size_per_node;
let gpu_conf = MultiGPUConfig {
tp_size,
tp_rank,
gpu_id,
};
let (sglang_process, ready_fd) = start_sglang(
model_path,
node_conf.clone(),
gpu_conf,
extra_engine_args.clone(),
)
.await?;
process_group.push((tp_rank, ready_fd));
let watcher_join_handle = watch_sglang(cancel_token.clone(), sglang_process);
// TODO: Do we want to hold on to this?
// Do we need it for the other sub-processes?
if sglang_join_handle.is_none() {
sglang_join_handle = Some(watcher_join_handle);
}
}
for (tp_rank, ready_fd) in process_group.into_iter() {
wait_for_sglang(tp_rank, ready_fd).await?;
}
let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let (tx, rx) = tokio::sync::mpsc::channel(8);
let input_loop_handle = {
let cancel_token = cancel_token.clone();
let py_imports = py_imports.clone();
let active_requests = active_requests.clone();
tokio::spawn(input_loop(
cancel_token,
py_imports,
input,
active_requests,
rx,
))
};
let output_loop_handle = {
let cancel_token = cancel_token.clone();
let py_imports = py_imports.clone();
let active_requests = active_requests.clone();
tokio::spawn(output_loop(
cancel_token,
py_imports,
output,
active_requests,
))
};
Ok(SgLangWorker {
tx,
zmq_context: context,
_input_loop: input_loop_handle,
_output_loop: output_loop_handle,
sglang: sglang_join_handle,
_active_requests: active_requests,
})
}
/// Import all the python packages we'll need.
fn python_imports() -> Imports {
Python::with_gil(|py| {
let pickle_module: PyObject = match py.import("pickle") {
Ok(m) => m.into(),
Err(err) => {
// There is no sglang without python
panic!("Failed to import python 'pickle' module. Is Python installed? {err}");
}
};
// This one is a sanity check
if let Err(err) = py.import("sglang") {
panic!("Failed to import python 'sglang' module. Are we running in the correct venv? {err}");
};
let mod_iostruct: PyObject = match py.import("sglang.srt.managers.io_struct") {
Ok(m) => m.into(),
Err(err) => {
panic!("Failed to import sglang.srt.managers.io_struct. Did sglang change? {err}");
}
};
let rpc_type = mod_iostruct
.getattr(py, "TokenizedGenerateReqInput")
.unwrap();
let mod_sampling: PyObject = match py.import("sglang.srt.sampling.sampling_params") {
Ok(m) => m.into(),
Err(err) => {
panic!(
"Failed to import sglang.srt.sampling.sampling_params. Did sglang change? {err}"
);
}
};
let sampling_params_type: PyObject = mod_sampling.getattr(py, "SamplingParams").unwrap();
Imports {
pickle_module,
sampling_params_type,
rpc_type,
}
})
}
/// Create all the zmq sockets we're going to use.
fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> {
let zmq_context = async_zmq::Context::new();
// Scheduler (rank 0) to receive inputs from us
let input = async_zmq::push(&format!("ipc:///tmp/{sock_code}_input_socket"))?
.with_context(&zmq_context)
.bind()?;
// Used to receive replies from the scheduler.
let output = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_output_socket"))?
.with_context(&zmq_context)
.bind()?;
Ok(Sockets {
context: zmq_context,
input,
output,
})
}
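// Note: these are the same ipc:// paths the Python side builds in
// sglang_inc.py (PortArgs scheduler_input_ipc_name / detokenizer_ipc_name),
// so the Rust and sglang halves of each PUSH/PULL pair meet by socket name.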
/// Start the python sub-process and wait for it to be ready
async fn start_sglang(
model_path: &Path,
node_conf: MultiNodeConfig,
gpu_conf: MultiGPUConfig,
extra_engine_args: Option<PathBuf>,
) -> anyhow::Result<(tokio::process::Child, RawFd)> {
// This pipe is how sglang tells us it's ready
let mut pipe_fds: [libc::c_int; 2] = [-1, -1];
unsafe {
// Seems to be OK without libc::O_NONBLOCK
let err = libc::pipe(pipe_fds.as_mut_ptr() as *mut c_int);
if err != 0 {
anyhow::bail!("libc::pipe error {err}");
}
}
let sglang_says_hello = pipe_fds[1] as RawFd;
let tp_rank = gpu_conf.tp_rank;
let gpu_id = gpu_conf.gpu_id;
let mut args = vec![
format!("--internal-sglang-process={sglang_says_hello},{tp_rank},{gpu_id}"),
format!("--model-path={}", model_path.display()),
format!("--tensor-parallel-size={}", gpu_conf.tp_size),
format!("--num-nodes={}", node_conf.num_nodes),
format!("--node-rank={}", node_conf.node_rank),
];
if let Some(extra_engine_args) = extra_engine_args {
args.push(format!(
"--extra-engine-args={}",
extra_engine_args.display()
));
};
if node_conf.num_nodes > 1 {
if node_conf.leader_addr.is_empty() {
anyhow::bail!("Missing --leader-addr for multi-node");
}
args.push(format!("--leader-addr={}", node_conf.leader_addr));
}
let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path)
.args(args)
.kill_on_drop(false)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = tokio::io::BufReader::new(proc.stdout.take().unwrap());
let stderr = tokio::io::BufReader::new(proc.stderr.take().unwrap());
// Log sglang's stdout
// sglang has (almost?) no output on stdout
tokio::spawn(async move {
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::debug!("{LOG_PREFIX}{tp_rank} {line}");
}
});
// Log sglang's stderr
tokio::spawn(async move {
// Remove extra date/time entries from stderr, and print with prefix
let line_re = Regex::new(SGLANG_LOG_RE).unwrap();
let mut lines = stderr.lines();
while let Ok(Some(line)) = lines.next_line().await {
if let Some(caps) = line_re.captures(&line) {
match caps.name("timestamp") {
Some(_) => {
// Skip Python's date/time. Should be normal log
tracing::debug!("{LOG_PREFIX}{tp_rank} {}", &caps["message"]);
}
None => {
// No date/time. Usually errors
tracing::warn!("{LOG_PREFIX}{tp_rank} {line}");
}
}
}
}
});
let ready_fd = pipe_fds[0] as RawFd;
Ok((proc, ready_fd))
}
async fn wait_for_sglang(tp_rank: u32, pipe_fd: RawFd) -> anyhow::Result<()> {
tracing::info!("Waiting for sglang{tp_rank} to signal that it's ready");
let mut sglang_ready = unsafe { tokio::fs::File::from_raw_fd(pipe_fd) };
let mut buf = [0u8; 128]; // Some pickled JSON, about 90 bytes
let len_read = sglang_ready
.read(&mut buf)
.await
.with_context(|| format!("Failed reading from Rust side of sglang pipe, fd {pipe_fd}",))?;
let received_bytes = &buf[..len_read];
/* received_bytes is pickled JSON:
{
"status": "ready",
"max_total_num_tokens": scheduler.max_total_num_tokens,
"max_req_input_len": scheduler.max_req_input_len,
}
We could unpickle it, but this is faster.
*/
if !received_bytes
.windows(READY_BYTES.len())
.any(|candidate| candidate == READY_BYTES)
{
anyhow::bail!("Expected sglang pipe to signal ready, but did not contain 'ready' bytes");
}
// TODO: warm up the engine
tracing::info!("sglang{tp_rank} is ready");
Ok(())
}
// Stop the sglang process when we stop, and prevent it from going zombie.
fn watch_sglang(
cancel_token: CancellationToken,
mut sglang_process: tokio::process::Child,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
cancel_token.cancelled().await;
tokio::select! {
_ = sglang_process.wait() => {
return;
},
_ = tokio::time::sleep(SGLANG_STOP_TIMEOUT) => { }
}
if let Err(err) = sglang_process.start_kill() {
tracing::error!("Failing killing sglang subprocess: {err}");
return;
}
tokio::select! {
_ = sglang_process.wait() => { },
_ = tokio::time::sleep(SGLANG_STOP_TIMEOUT) => {
tracing::warn!("Timeout waiting for sglang sub-process to stop after kill");
}
}
})
}
async fn input_loop(
cancel_token: CancellationToken,
py_imports: Arc<Imports>,
mut input_socket: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
mut rx: tokio::sync::mpsc::Receiver<WorkRequest>,
) {
loop {
let work_request = tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("SgLangWorker.main_loop exit");
break;
}
req = rx.recv() => {
match req {
Some(req) => req,
None => {
tracing::trace!("SgLangWorker input_loop socket closed");
break;
}
}
}
};
let request_id = work_request.request_id;
let token_ids = work_request.request.token_ids.clone();
let temperature: f64 = work_request
.request
.sampling_options
.temperature
.unwrap_or(0.0)
.into();
tracing::trace!("Received work request: {request_id}");
// Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let mut sp_kwargs = vec![("temperature", py_temp)];
if let Some(max_tokens) = work_request.request.stop_conditions.max_tokens {
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
// sglang defaults this to 128
sp_kwargs.push(("max_new_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = py_imports
.sampling_params_type
.call(py, (), Some(&sp_kwargs))
.unwrap();
sampling_params
.getattr(py, "normalize")
.unwrap()
.call1(py, (py.None(),))
.unwrap();
let py_request_id: PyObject = PyString::new(py, &request_id).into();
(py_request_id, sampling_params)
});
let pickled_req: Vec<u8> = Python::with_gil(|py| {
let input_text: PyObject = "".into_pyobject(py).unwrap().into();
let input_ids: PyObject = token_ids.into_pyobject(py).unwrap().into();
let image_inputs: PyObject = py.None();
let return_logprob: PyObject = false.into_pyobject(py).unwrap().to_owned().into();
let logprob_start_len: PyObject = 0u32.into_pyobject(py).unwrap().into();
let top_logprobs_num: PyObject = 0u32.into_pyobject(py).unwrap().into();
let stream: PyObject = true.into_pyobject(py).unwrap().to_owned().into();
let rpc_pos_args = (
py_request_id,
input_text,
input_ids,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
stream,
);
//let rpc_kwargs = [].into_py_dict(py).unwrap();
let req = py_imports
.rpc_type
.call(py, rpc_pos_args, None) // Some(&rpc_kwargs))
.unwrap();
let pickle_dumps = py_imports.pickle_module.getattr(py, "dumps").unwrap();
pickle_dumps.call1(py, (req,)).unwrap().extract(py).unwrap()
});
let new_active_request = ActiveRequest {
tx: work_request.response_channel,
num_output_tokens_so_far: None,
};
active_requests
.lock()
.await
.insert(request_id, new_active_request);
if let Err(err) = input_socket.send(pickled_req.into()).await {
tracing::error!("Error sending new request to sglang over zmq: {err}");
}
}
}
/// Read from sglang's output zmq socket, find which request each message is for, and forward it over that request's channel.
async fn output_loop(
cancel_token: CancellationToken,
py_imports: Arc<Imports>,
mut output_socket: async_zmq::Pull,
active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
) {
loop {
let maybe_bb = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_bb = output_socket.next() => {
maybe_bb
}
};
let mut bb = match maybe_bb {
Some(Ok(b)) => b,
Some(Err(err)) => {
tracing::error!("Error reading from sglang zmq output: {err}");
continue; // hope springs eternal
}
None => {
tracing::debug!("zmq output socket closed");
break;
}
};
let frame = bb.remove(0);
let req_out: BatchTokenIDOut = Python::with_gil(|py| {
let pickle_loads = py_imports.pickle_module.getattr(py, "loads").unwrap();
let frame_bytes = PyBytes::new(py, &frame);
let pyobj = pickle_loads.call1(py, (frame_bytes,)).unwrap();
pyobj.extract(py).unwrap()
});
tracing::trace!(?req_out, "from sglang");
let mut remove_after = vec![];
for (idx, req_id) in req_out.rids.into_iter().enumerate() {
let next_total_toks = req_out.decode_ids[idx].len() as i32;
match active_requests.lock().await.get_mut(&req_id) {
Some(active) => {
let previous_total_toks = active
.num_output_tokens_so_far
.unwrap_or(req_out.read_offsets[idx])
as usize;
let sglang_finish_reason = req_out.finished_reasons[idx];
let token_ids: Vec<TokenIdType> = if sglang_finish_reason.is_none() {
req_out.decode_ids[idx][previous_total_toks..].into()
} else {
tracing::trace!(
req_id,
?sglang_finish_reason,
"finished with finish reason"
);
// Request is over, sglang says so.
// The last token is the eos_token, don't forward it
remove_after.push(req_id.clone());
vec![]
};
let out = LLMEngineOutput {
token_ids,
tokens: None,
text: None,
cum_log_probs: None,
log_probs: None,
finish_reason: sglang_finish_reason.map(|x| x.into()),
};
active.num_output_tokens_so_far = Some(next_total_toks);
let _ = active.tx.send(Annotated::from_data(out)).await;
}
None => {
// sglang sends the finish response twice, I don't know why
// so only log if it isn't a finished request
if req_out.finished_reasons[idx].is_none() {
tracing::warn!(req_id, "Missing active request");
}
}
}
}
for req_id in remove_after {
let _ = active_requests.lock().await.remove(&req_id);
}
}
}
impl SgLangWorker {
/// Send a request to sglang
pub async fn enqueue_request(&self, r: WorkRequest) -> Result<(), SendError<WorkRequest>> {
self.tx.send(r).await
}
/// Get the sglang sub-process handle, so we can await it and prevent it from going zombie.
pub fn take_sglang_handle(&mut self) -> JoinHandle<()> {
self.sglang.take().unwrap()
}
}
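// A usage sketch, with assumed names (`worker` is a started SgLangWorker,
// `preprocessed` a PreprocessedRequest built elsewhere):
//
// let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(128);
// worker.enqueue_request(WorkRequest {
//     request: preprocessed,
//     request_id: "req-1".to_string(),
//     response_channel: resp_tx,
// }).await?;
// while let Some(out) = resp_rx.recv().await { /* stream tokens out */ }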
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "dynamo-engine-vllm0_7"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[dependencies]
dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true }
anyhow = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
async_zmq = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
async-openai = "0.27.2"
pyo3 = { version = "0.23.3", default-features = false, features = [
"macros",
"experimental-async",
"experimental-inspect",
"py-clone",
] }
regex = "1"
serde-pickle = "1.2.0"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::publisher::KvMetricsPublisher;
use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::runtime::CancellationToken;
use crate::worker;
pub struct VllmEngine {
cancel_token: CancellationToken,
worker: worker::VllmWorker,
}
impl VllmEngine {
pub async fn new(
cancel_token: CancellationToken,
sock_code: &str,
model_path: &Path,
node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> anyhow::Result<Self> {
let w = worker::start(
cancel_token.clone(),
sock_code,
model_path,
node_conf,
tensor_parallel_size,
extra_engine_args,
kv_metrics_publisher,
)
.await?;
let engine = VllmEngine {
cancel_token,
worker: w,
};
Ok(engine)
}
pub fn take_vllm_worker_handle(&mut self) -> tokio::task::JoinHandle<()> {
self.worker.take_vllm_handle()
}
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for VllmEngine
{
async fn generate(
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let (request, context) = request.into_parts();
let ctx = context.context();
let request_id = ctx.id().to_string();
let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(128);
let work_req = worker::WorkRequest {
request_id: context.id().to_string(),
request,
response_channel: resp_tx,
};
self.worker.enqueue_request(work_req).await?;
let cancel_token = self.cancel_token.clone();
let output = stream! {
loop {
let maybe_resp = tokio::select!{
_ = cancel_token.cancelled() => {
break;
}
maybe_resp = resp_rx.recv() => {
maybe_resp
}
};
match maybe_resp {
Some(out) => {
yield out;
},
None => {
tracing::trace!(request_id, "generate: response channel closed");
break;
}
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::future::Future;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use pyo3::prelude::*;
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::CancellationToken;
use dynamo_llm::backend::ExecutionContext;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::kv_router::publisher::KvMetricsPublisher;
mod engine;
use engine::VllmEngine;
mod ray;
use ray::Ray;
mod subprocess;
pub use subprocess::run_subprocess;
mod worker;
pub async fn make_leader_engine(
cancel_token: CancellationToken,
// Full path to the model, either a GGUF file or an HF repo dir
model_path: &Path,
// Unique string to name zmq sockets
sock_code: &str,
// Multi node settings
node_conf: MultiNodeConfig,
// How many GPUs to use
tensor_parallel_size: u32,
// Path to extra engine args file
extra_engine_args: Option<PathBuf>,
// When using our vllm fork, this is how we publish its KV metrics for the KV router
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> pipeline_error::Result<(ExecutionContext, impl Future<Output = ()>)> {
let ray_obj = if node_conf.num_nodes > 1 {
let r = ray::start_leader(node_conf.leader_addr.parse()?)?;
tracing::info!("Leader waiting for {} total nodes", node_conf.num_nodes);
r.wait_for(cancel_token.clone(), node_conf.num_nodes)
.await?;
tracing::info!("All nodes registered");
Some(r)
} else {
None
};
let mut engine = VllmEngine::new(
cancel_token,
sock_code,
model_path,
node_conf,
tensor_parallel_size,
extra_engine_args,
kv_metrics_publisher,
)
.await?;
let vllm_process = engine.take_vllm_worker_handle();
let vllm_future = async move {
if let Err(err) = vllm_process.await {
tracing::error!("Failed stopping vllm process: {err:#}");
}
if let Some(r) = ray_obj {
if let Err(err) = r.stop().await {
tracing::error!("Failed stopping ray: {err:#}");
}
}
};
let engine: ExecutionContext = Arc::new(engine);
Ok((engine, vllm_future))
}
pub async fn start_follower(
cancel_token: CancellationToken,
node_conf: MultiNodeConfig,
) -> pipeline_error::Result<StopFuture> {
let r = ray::start_follower(node_conf.leader_addr.parse()?)?;
tracing::info!("Follower waiting for {} total nodes", node_conf.num_nodes);
r.wait_for(cancel_token, node_conf.num_nodes).await?;
tracing::info!("All nodes registered");
Ok(StopFuture {
state: Some(StopFutureState::New(r)),
})
}
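// Multi-node flow, sketched from the functions above: the leader node calls
// make_leader_engine(), which runs `ray start --head` and blocks until all
// nodes register; every other node calls start_follower(), which joins the
// leader's ray cluster and then awaits the returned StopFuture so that
// `ray stop` runs on shutdown.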
pub struct StopFuture {
state: Option<StopFutureState>,
}
enum StopFutureState {
New(Ray),
Running(Pin<Box<dyn Future<Output = ()> + Send>>),
}
impl Future for StopFuture {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let state = match self.state.take() {
None => return Poll::Ready(()),
Some(state) => state,
};
match state {
StopFutureState::New(obj) => {
// Convert object to a stop future
let future = Box::pin(async move {
if let Err(err) = obj.stop().await {
tracing::error!("Failed calling 'ray stop': {err:#}");
}
});
self.state = Some(StopFutureState::Running(future));
// Recurse to poll the new future immediately
self.poll(cx)
}
StopFutureState::Running(mut future) => {
// Poll the stop future
match future.as_mut().poll(cx) {
Poll::Ready(()) => {
// Done, leave state as None
Poll::Ready(())
}
Poll::Pending => {
// Not ready yet, preserve the future
self.state = Some(StopFutureState::Running(future));
Poll::Pending
}
}
}
}
}
}
#[cfg(target_os = "macos")]
fn fix_venv(venv: String, py: Python<'_>) -> anyhow::Result<()> {
let version_info = py.version_info();
let sys: PyObject = py.import("sys")?.into();
let sys_path = sys.getattr(py, "path")?;
let venv_path = format!(
"{venv}/lib/python{}.{}/site-packages",
version_info.major, version_info.minor
);
// TODO: This should go _before_ the site-packages
sys_path.call_method1(py, "append", (venv_path,))?;
Ok(())
}
#[cfg(not(target_os = "macos"))]
fn fix_venv(_venv: String, _py: Python<'_>) -> anyhow::Result<()> {
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use regex::Regex;
use std::io::{BufRead, BufReader};
use std::net::SocketAddrV4;
use std::process::{Command, Stdio};
use std::time::Duration;
use thiserror::Error;
use tokio::io::AsyncBufReadExt;
use tokio::select;
use tokio::time;
use dynamo_runtime::CancellationToken;
/// Ray's default grace period is 16 seconds; we make it a bit shorter
const RAY_STOP_TIMEOUT_SECS: u32 = 10;
/// How long to wait for all the nodes to start.
/// This is either done manually or through some orchestration system, so either way it
/// can take some time.
const RAY_WAIT_SECS: u32 = 60 * 5;
#[derive(Debug, Error)]
pub enum RayError {
#[error("Failed to execute Ray command: {0}")]
CommandExecution(#[from] std::io::Error),
#[error("Ray command failed with exit code: {0}")]
CommandFailed(i32),
#[error("Failed to parse Ray status output")]
StatusParseError,
#[error("Timeout waiting for nodes to become active")]
WaitTimeout,
#[error("Operation cancelled")]
Cancelled,
}
#[derive(Debug, PartialEq)]
pub struct RayStatus {
pub active_nodes: Vec<String>,
pub pending_nodes_count: usize,
pub recent_failures_count: usize,
}
pub struct Ray {
#[allow(dead_code)]
leader_address: SocketAddrV4,
}
pub fn start_leader(leader_address: SocketAddrV4) -> Result<Ray, RayError> {
let ip = leader_address.ip().to_string();
let port = leader_address.port().to_string();
let mut cmd = Command::new("ray");
cmd.args([
"start",
"--head",
"--disable-usage-stats",
"--log-style=record",
&format!("--node-ip-address={}", ip),
&format!("--port={}", port),
]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
// Process stdout
if let Some(stdout) = child.stdout.take() {
let reader = BufReader::new(stdout);
for line in reader.lines().map_while(Result::ok) {
tracing::info!("RAY: {line}");
}
}
// Process stderr
if let Some(stderr) = child.stderr.take() {
let reader = BufReader::new(stderr);
for line in reader.lines().map_while(Result::ok) {
tracing::info!("RAY: {line}");
}
}
let status = child.wait()?;
if !status.success() {
return Err(RayError::CommandFailed(status.code().unwrap_or(-1)));
}
Ok(Ray { leader_address })
}
pub fn start_follower(leader_address: SocketAddrV4) -> Result<Ray, RayError> {
let address = leader_address.to_string();
let mut cmd = Command::new("ray");
cmd.args(["start", &format!("--address={address}")]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
// Process stdout
if let Some(stdout) = child.stdout.take() {
let reader = BufReader::new(stdout);
for line in reader.lines().map_while(Result::ok) {
tracing::info!("RAY: {line}");
}
}
// Process stderr
if let Some(stderr) = child.stderr.take() {
let reader = BufReader::new(stderr);
for line in reader.lines().map_while(Result::ok) {
tracing::info!("RAY: {line}");
}
}
let status = child.wait()?;
if !status.success() {
return Err(RayError::CommandFailed(status.code().unwrap_or(-1)));
}
Ok(Ray { leader_address })
}
impl Ray {
pub fn status(&self) -> Result<RayStatus, RayError> {
let output = Command::new("ray").arg("status").output()?;
if !output.status.success() {
return Err(RayError::CommandFailed(output.status.code().unwrap_or(-1)));
}
let output_str = String::from_utf8_lossy(&output.stdout);
parse_ray_status(&output_str).ok_or(RayError::StatusParseError)
}
pub async fn wait_for(
&self,
cancel_token: CancellationToken,
num_nodes: u32,
) -> Result<(), RayError> {
let timeout = time::sleep(Duration::from_secs(RAY_WAIT_SECS as u64));
select! {
_ = cancel_token.cancelled() => {
Err(RayError::Cancelled)
}
_ = timeout => {
Err(RayError::WaitTimeout)
}
result = self.wait_for_nodes(num_nodes) => {
result
}
}
}
async fn wait_for_nodes(&self, num_nodes: u32) -> Result<(), RayError> {
loop {
let status = self.status()?;
if status.active_nodes.len() as u32 == num_nodes {
return Ok(());
}
time::sleep(Duration::from_millis(100)).await;
}
}
pub async fn stop(&self) -> Result<(), RayError> {
let mut cmd = tokio::process::Command::new("ray");
cmd.args([
"stop",
&format!("--grace-period={RAY_STOP_TIMEOUT_SECS}"),
"--log-style=record",
]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
// Process stdout
if let Some(stdout) = child.stdout.take() {
let reader = tokio::io::BufReader::new(stdout);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("RAY: {line}");
}
}
// Process stderr
if let Some(stderr) = child.stderr.take() {
let reader = tokio::io::BufReader::new(stderr);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("RAY: {line}");
}
}
let status = child.wait().await?;
if !status.success() {
return Err(RayError::CommandFailed(status.code().unwrap_or(-1)));
}
Ok(())
}
}
/// Parse the output of the `ray status` command into a RayStatus struct
fn parse_ray_status(output: &str) -> Option<RayStatus> {
let mut active_nodes = Vec::new();
let mut pending_nodes_count = 0;
let mut recent_failures_count = 0;
// Flags to track which section we're in
let mut in_active_section = false;
let mut in_pending_section = false;
let mut in_failures_section = false;
// Regex to match node IDs
let node_regex = Regex::new(r"(\d+)\s+(node_[a-f0-9]+)").unwrap();
let num_regex = Regex::new(r"(\d+)").unwrap();
for line in output.lines() {
let trimmed = line.trim();
if trimmed == "Active:" {
in_active_section = true;
in_pending_section = false;
in_failures_section = false;
continue;
} else if trimmed == "Pending:" {
in_active_section = false;
in_pending_section = true;
in_failures_section = false;
continue;
} else if trimmed == "Recent failures:" {
in_active_section = false;
in_pending_section = false;
in_failures_section = true;
continue;
} else if trimmed.starts_with("Resources") {
// We've reached the end of the node status section
break;
}
if in_active_section {
if let Some(captures) = node_regex.captures(trimmed) {
if let Some(node_id) = captures.get(2) {
active_nodes.push(node_id.as_str().to_string());
}
}
} else if in_pending_section && trimmed != "(no pending nodes)" {
// Count pending nodes
if let Some(captures) = num_regex.captures(trimmed) {
if let Some(count) = captures.get(1) {
if let Ok(count) = count.as_str().parse::<usize>() {
pending_nodes_count += count;
}
}
}
} else if in_failures_section && trimmed != "(no failures)" {
// Count failures
if let Some(captures) = num_regex.captures(trimmed) {
if let Some(count) = captures.get(1) {
if let Ok(count) = count.as_str().parse::<usize>() {
recent_failures_count += count;
}
}
}
}
}
Some(RayStatus {
active_nodes,
pending_nodes_count,
recent_failures_count,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_ray_status() {
let sample_output = r#"======== Autoscaler status: 2025-03-04 13:13:59.104771 ========
Node status
---------------------------------------------------------------
Active:
1 node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f
1 node_035ea3b640e13f3603d3debd97de8c569ed8c8b10e19ce00ea4fd070
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/256.0 CPU
0.0/16.0 GPU
0B/1.58TiB memory
0B/372.53GiB object_store_memory
Demands:
(no resource demands)
"#;
let expected = RayStatus {
active_nodes: vec![
"node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f".to_string(),
"node_035ea3b640e13f3603d3debd97de8c569ed8c8b10e19ce00ea4fd070".to_string(),
],
pending_nodes_count: 0,
recent_failures_count: 0,
};
let result = parse_ray_status(sample_output);
assert!(result.is_some());
assert_eq!(result.unwrap(), expected);
}
/// Test with pending nodes and failures
#[test]
fn test_parse_ray_status_with_failing() {
let sample_output_with_pending = r#"======== Autoscaler status: 2025-03-04 13:13:59.104771 ========
Node status
---------------------------------------------------------------
Active:
1 node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f
Pending:
2 node_pending_1
3 node_pending_2
Recent failures:
1 node_failure_1
4 node_failure_2
Resources
---------------------------------------------------------------
Usage:
0.0/256.0 CPU
"#;
let expected_with_pending = RayStatus {
active_nodes: vec![
"node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f".to_string(),
],
pending_nodes_count: 5, // 2 + 3
recent_failures_count: 5, // 1 + 4
};
let result = parse_ray_status(sample_output_with_pending);
assert!(result.is_some());
assert_eq!(result.unwrap(), expected_with_pending);
}
/// Test with empty output
#[test]
fn test_parse_ray_status_empty() {
let empty_output = "";
let result = parse_ray_status(empty_output);
assert!(result.is_some());
assert_eq!(result.unwrap().active_nodes.len(), 0);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use pyo3::{types::IntoPyDict, Python};
use std::env;
use std::ffi::CString;
use std::path::{Path, PathBuf};
use dynamo_llm::engines::MultiNodeConfig;
const PY_START_ENGINE: &str = include_str!("vllm_inc.py");
/// Start the Python vllm engine that listens on a zmq socket.
/// This is called by running `<bin> --internal-vllm-process`.
/// This does not return until vllm exits.
pub fn run_subprocess(
socket_id: &str,
model_path: &Path,
node_config: MultiNodeConfig,
tp_size: u32,
extra_engine_args: Option<PathBuf>,
with_kv_routing: bool,
) -> anyhow::Result<()> {
if with_kv_routing {
set_kv_routing_vars()?;
}
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") {
let _ = Python::with_gil(|py| crate::fix_venv(venv, py));
}
let model_path_str = model_path.display().to_string();
let extra_engine_args_str = &extra_engine_args
.map(|p| p.display().to_string())
.unwrap_or_default();
Python::with_gil(|py| {
let locals = [
("socket_id", socket_id),
("model_path", model_path_str.as_str()),
("tp_size_str", &tp_size.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()),
("extra_engine_args", extra_engine_args_str),
("enable_prefix_caching", &with_kv_routing.to_string()),
]
.into_py_dict(py)
.unwrap();
if let Err(err) = py.run(CString::new(PY_START_ENGINE)?.as_ref(), None, Some(&locals)) {
anyhow::bail!("vllm engine run error: {err}");
}
tracing::info!("vllm subprocess exit");
Ok(())
})
}
// These environment variables trigger our vllm patch to emit KV routing events
fn set_kv_routing_vars() -> anyhow::Result<()> {
let exe = env::current_exe()?;
let exe_dir = exe
.parent()
.ok_or(anyhow::anyhow!("Current binary has no directory"))?;
let mut lib = PathBuf::from(exe_dir);
lib.set_file_name("libdynamo_llm_capi.so");
let vars = [
// Path to the C API Library
("VLLM_KV_CAPI_PATH", lib.display().to_string()),
// Identifiers to publish KV related information
("VLLM_KV_NAMESPACE", "dynamo".to_string()),
("VLLM_KV_COMPONENT", "vllm".to_string()),
// Worker ID used for identifying workers in distributed settings
("VLLM_WORKER_ID", "0".to_string()),
];
for (kvar, default_v) in vars {
if env::var(kvar).is_err() {
env::set_var(kvar, default_v);
}
}
Ok(())
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# This file is included as a string in subprocess.rs. Most work should be done in the Rust caller.
#
import json
import logging
import multiprocessing
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.usage.usage_lib import UsageContext
arg_map = {
"model": f"{model_path}",
"served_model_name": None,
"task": "generate",
"skip_tokenizer_init": True,
"seed": 0,
"max_model_len": 8192,
"max_seq_len_to_capture": 8192,
"tensor_parallel_size": int(tp_size_str),
"pipeline_parallel_size": int(nnodes_str),
"enable_prefix_caching": enable_prefix_caching.lower() == "true",
}
json_map = {}
if extra_engine_args != "":
# extra_engine_args is a filename
try:
with open(extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.debug(f"File {extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.debug(f"Invalid JSON in {extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
engine_args = AsyncEngineArgs(**arg_map)
ipc_path = f"ipc:///tmp/{socket_id}"
engine_alive = multiprocessing.Value("b", True, lock=False)
run_mp_engine(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path, engine_alive)
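# run_mp_engine serves vllm's multiprocessing engine over sockets under
# ipc_path; the Rust side (worker.rs Sockets: data/input/output/heartbeat/
# metrics) is assumed to open the matching paths derived from the same
# socket_id.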
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::env;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use std::vec::IntoIter;
use async_zmq::{SinkExt, StreamExt};
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::CancellationToken;
use pyo3::{
prelude::*,
types::{IntoPyDict, PyBytes, PyString},
};
use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc::{error::SendError, Sender};
use tokio::task::JoinHandle;
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
use dynamo_llm::protocols::common::preprocessor::PreprocessedRequest;
use dynamo_llm::protocols::common::FinishReason;
use dynamo_llm::{engines::MultiNodeConfig, kv_router::publisher::KvMetricsPublisher};
/// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
// The vllm minor version this engine supports. 0.8+ is handled by a different engine.
const VLLM_VERSION: &str = "0.7";
type RequestID = String;
pub struct VllmWorker {
/// How we receive work requests
tx: Sender<WorkRequest>,
/// Handle of the task that reads from `tx` and forwards those requests over zmq to vllm
_input_loop: JoinHandle<()>,
/// Handle of the task that reads vllm's responses from zmq and dispatches them to the correct
/// active request.
_output_loop: JoinHandle<()>,
/// Handle of the vllm background process
vllm: Option<JoinHandle<()>>,
// We don't need to hold on to this, it's already shared between input_loop and output_loop
// But later we'll probably want stats - how many active requests etc, so keep it here
_active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
// Need to keep this alive
// TODO: With async_zmq we possibly don't need this at all
#[allow(dead_code)]
zmq_context: async_zmq::Context,
}
/// How we get asked to do some work. These get unpacked and forwarded to vllm.
pub struct WorkRequest {
pub request: PreprocessedRequest,
pub request_id: RequestID,
pub response_channel: Sender<Annotated<LLMEngineOutput>>,
}
/// A request currently being processed by vllm
struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: usize,
}
/// Python imports
struct Imports {
pickle_module: PyObject,
tokens_prompt_type: PyObject,
sample_params_type: PyObject,
rpc_type: PyObject,
startup_type: PyObject,
}
/// All the zmq sockets we use. This struct exists only to pass them around
/// together instead of as a large tuple.
struct Sockets {
#[allow(dead_code)]
context: async_zmq::Context, // we have to keep this alive
// Control socket: how we ask the vllm engine to start.
// Not the best name, but this is what vllm calls it internally.
data: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
// Requests from us to the vllm engine
input: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
// Responses from the vllm engine back to us
output: async_zmq::Pull,
// Heartbeat messages from vllm process
heartbeat: async_zmq::Pull,
// NOTE: Metrics socket usage is custom to our patch of vllm, and may not
// be present when running upstream vllm.
// Metrics messages from vllm process
metrics: async_zmq::Pull,
}
/// The message vllm sends us over zmq when it's ready to work.
#[derive(FromPyObject, Debug)]
struct RPCStartupResponse {
#[allow(dead_code)]
tracing_enabled: bool,
}
/// What vllm sends us. Usually it contains a single token.
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct RequestOutput {
request_id: String,
prompt: Option<String>,
prompt_token_ids: Option<Vec<u32>>,
prompt_logprobs: Option<Vec<Option<HashMap<u32, Logprob>>>>,
outputs: Vec<CompletionOutput>,
finished: bool,
//metrics: Optional[RequestMetrics] = None,
//lora_request: Optional[LoRARequest] = None,
encoder_prompt: Option<String>,
encoder_prompt_token_ids: Option<Vec<u32>>,
num_cached_tokens: Option<u32>,
}
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct CompletionOutput {
index: u32,
text: String,
token_ids: Vec<u32>,
cumulative_logprob: Option<f32>,
logprobs: Option<Vec<HashMap<u32, Logprob>>>,
finish_reason: Option<String>,
//stop_reason: Union[int, str, None] = None
//lora_request: Optional[LoRARequest] = None
}
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
struct Logprob {
logprob: f32,
rank: Option<u32>,
decoded_token: Option<String>,
}
/// Main entry point
pub async fn start(
cancel_token: CancellationToken,
sock_code: &str,
model_path: &Path,
_node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
// When using our vllm fork, this is how we publish its KV metrics for the KV router
kv_metrics_publisher: Option<Arc<KvMetricsPublisher>>,
) -> anyhow::Result<VllmWorker> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
if let Ok(venv) = env::var("VIRTUAL_ENV") {
let _ = Python::with_gil(|py| crate::fix_venv(venv, py));
}
let py_imports = Arc::new(python_imports());
let Sockets {
context,
data,
input,
output,
heartbeat,
metrics,
} = zmq_sockets(sock_code)?;
let vllm_process = start_vllm(
model_path,
&py_imports,
data,
tensor_parallel_size,
extra_engine_args,
kv_metrics_publisher.is_some(),
)
.await?;
let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));
tokio::spawn(metrics_loop(
cancel_token.clone(),
metrics,
kv_metrics_publisher.clone(),
));
let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let (tx, rx) = tokio::sync::mpsc::channel(8);
let input_loop_handle = {
let cancel_token = cancel_token.clone();
let py_imports = py_imports.clone();
let active_requests = active_requests.clone();
tokio::spawn(input_loop(
cancel_token,
py_imports,
input,
active_requests,
rx,
))
};
let output_loop_handle = {
let cancel_token = cancel_token.clone();
let py_imports = py_imports.clone();
let active_requests = active_requests.clone();
tokio::spawn(output_loop(
cancel_token,
py_imports,
output,
active_requests,
))
};
Ok(VllmWorker {
tx,
zmq_context: context,
_input_loop: input_loop_handle,
_output_loop: output_loop_handle,
vllm: Some(vllm_join_handle),
_active_requests: active_requests,
})
}
/// Import all the python packages we'll need. `vllm` in particular takes a few seconds.
fn python_imports() -> Imports {
Python::with_gil(|py| {
let pickle_module: PyObject = match py.import("pickle") {
Ok(m) => m.into(),
Err(err) => {
// There is no vllm without python
panic!("Failed to import python 'pickle' module. Is Python installed? {err}");
}
};
let vllm_module: PyObject = match py.import("vllm") {
Ok(m) => m.into(),
Err(err) => {
panic!("Failed to import python 'vllm' module. Are we running in the correct venv? {err}");
}
};
// While we're here check vllm version
let version = vllm_module
.getattr(py, "__version__")
.expect("vllm missing __version__ field")
.extract::<String>(py)
.expect("vllm.__version__ is not a string");
if !version.starts_with(VLLM_VERSION) {
panic!("Expected vllm version {VLLM_VERSION}, found {version}");
}
let tokens_prompt_type: PyObject = vllm_module.getattr(py, "TokensPrompt").unwrap();
let sample_params_type: PyObject = vllm_module.getattr(py, "SamplingParams").unwrap();
let mod_multiprocessing = py.import("vllm.engine.multiprocessing").unwrap();
let rpc_type: PyObject = mod_multiprocessing
.getattr("RPCProcessRequest")
.unwrap()
.into();
let startup_type: PyObject = mod_multiprocessing
.getattr("RPCStartupRequest")
.unwrap()
.into();
Imports {
pickle_module,
tokens_prompt_type,
sample_params_type,
rpc_type,
startup_type,
}
})
}
/// Create all the zmq sockets we're going to use.
fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> {
let zmq_context = async_zmq::Context::new();
let input = async_zmq::push(&format!("ipc:///tmp/{sock_code}_input_socket"))?
.with_context(&zmq_context)
.connect()?;
let output = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_output_socket"))?
.with_context(&zmq_context)
.connect()?;
let data = async_zmq::dealer(&format!("ipc:///tmp/{sock_code}_data_socket"))?
.with_context(&zmq_context)
.connect()?;
let heartbeat = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_health_socket"))?
.with_context(&zmq_context)
.connect()?;
let metrics = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_metrics_socket"))?
.with_context(&zmq_context)
.connect()?;
// TODO: NIXL/Prefill sockets here in the future for disagg?
Ok(Sockets {
context: zmq_context,
data,
input,
output,
heartbeat,
metrics,
})
}
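// For a hypothetical sock_code of "f00d", the sockets above connect to:
//   ipc:///tmp/f00d_input_socket   (push:   requests to vllm)
//   ipc:///tmp/f00d_output_socket  (pull:   responses from vllm)
//   ipc:///tmp/f00d_data_socket    (dealer: startup handshake)
//   ipc:///tmp/f00d_health_socket  (pull:   heartbeats)
//   ipc:///tmp/f00d_metrics_socket (pull:   KV metrics, fork-only)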
/// Start the vllm python sub-process and wait for it to start
async fn start_vllm(
model_path: &Path,
python_imports: &Imports,
mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
with_kv_routing: bool,
) -> anyhow::Result<tokio::process::Child> {
let mut vllm_args = vec![
"--internal-vllm-process".to_string(),
format!("--model-path={}", model_path.display()),
format!("--tensor-parallel-size={tensor_parallel_size}"),
];
if let Some(args_path) = extra_engine_args {
vllm_args.push(format!("--extra-engine-args={}", args_path.display()));
}
if with_kv_routing {
vllm_args.push("--router-mode=kv".to_string());
}
let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path)
.env("VLLM_LOGGING_LEVEL", "DEBUG")
.args(&vllm_args)
.kill_on_drop(false)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = tokio::io::BufReader::new(proc.stdout.take().unwrap());
let stderr = tokio::io::BufReader::new(proc.stderr.take().unwrap());
tokio::spawn(async move {
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
let mut line_parts = line.splitn(4, ' ');
let mut log_level = line_parts.next().unwrap_or_default();
// nth(2) skips the date and time fields and yields the remainder of the line.
let line = line_parts.nth(2).unwrap_or_default();
if line.starts_with("custom_op.py:68") || line.trim().is_empty() {
// Skip a noisy line
// custom_op.py:68] custom op <the op> enabled
continue;
}
if line.contains("ERROR") {
log_level = "ERROR";
}
match log_level {
"DEBUG" => tracing::debug!("VLLM: {line}"),
"INFO" => tracing::debug!("VLLM: {line}"), // VLLM is noisy in debug mode
"WARNING" => tracing::warn!("VLLM: {line}"),
"ERROR" => tracing::error!("VLLM: {line}"),
level => tracing::info!("VLLM: {level} {line}"),
}
}
});
tokio::spawn(async move {
let mut lines = stderr.lines();
while let Ok(Some(line)) = lines.next_line().await {
if line.trim().is_empty() {
continue;
}
tracing::warn!("VLLM: {line}");
}
});
let start_req_bytes: Vec<u8> = Python::with_gil(|py| {
let start_req = python_imports
.startup_type
.getattr(py, "IS_SERVER_READY")
.unwrap();
let pickle_dumps = python_imports.pickle_module.getattr(py, "dumps").unwrap();
pickle_dumps
.call1(py, (start_req,))
.unwrap()
.extract(py)
.unwrap()
});
data_socket.send(vec![start_req_bytes].into()).await?;
let start_resp: Vec<u8> = match data_socket.next().await {
Some(Ok(r)) => {
if !r.is_empty() {
r[0].deref().to_vec()
} else {
anyhow::bail!("vllm failed to start. No response on dealer/data socket");
}
}
Some(Err(err)) => {
anyhow::bail!("vllm failed to start. Error reading from dealer/data socket: {err}");
}
None => {
anyhow::bail!("vllm failed to start. dealer/data socket is closed.");
}
};
let resp: RPCStartupResponse = Python::with_gil(|py| {
let pickle_loads = python_imports.pickle_module.getattr(py, "loads").unwrap();
pickle_loads
.call1(py, (start_resp,))
.unwrap()
.extract(py)
.unwrap()
});
tracing::debug!("vllm zmq backend is ready: {resp:?}");
Ok(proc)
}
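// The startup handshake above, in summary: pickle RPCStartupRequest.IS_SERVER_READY,
// send it on the dealer ("data") socket, then block until vllm replies with a
// pickled RPCStartupResponse. Any error or a closed socket at this stage aborts startup.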
// Stop the vllm process when we stop, and prevent it from going zombie.
fn watch_vllm(
cancel_token: CancellationToken,
mut vllm_process: tokio::process::Child,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
cancel_token.cancelled().await;
tokio::select! {
_ = vllm_process.wait() => {
return;
},
_ = tokio::time::sleep(VLLM_STOP_TIMEOUT) => { }
}
if let Err(err) = vllm_process.start_kill() {
tracing::error!("Failing killing vllm subprocess: {err}");
return;
}
tokio::select! {
_ = vllm_process.wait() => { },
_ = tokio::time::sleep(VLLM_STOP_TIMEOUT) => {
tracing::warn!("Timeout waiting for vllm sub-process to stop after kill");
}
}
})
}
// How we know the vllm engine is alive: it sends "SUCCESS" as a pickled string every 10s.
// TODO: If heartbeats stop arriving we should react: vllm is gone, so at least
// de-register the model.
async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pull) {
loop {
let maybe_hb = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_hb = socket.next() => {
maybe_hb
}
};
let b = match maybe_hb {
Some(Ok(b)) => b[0].deref().to_vec(),
Some(Err(err)) => {
tracing::error!("Error reading from vllm heartbeat socket: {err}");
break;
}
None => {
tracing::debug!("vllm heartbeat socket closed");
break;
}
};
let s: String = match serde_pickle::from_slice(&b, Default::default()) {
Ok(s) => s,
Err(err) => {
tracing::error!("Error de-serializing vllm heartbeat response. It was probably Exception not str. {err}");
break;
}
};
if s != "SUCCESS" {
tracing::error!("vllm heartbeat error, expected 'SUCCESS' got '{s}'");
break;
}
}
}
// NOTE: Custom to our patch of vllm.
async fn metrics_loop(
cancel_token: CancellationToken,
mut socket: async_zmq::Pull,
publisher: Option<Arc<KvMetricsPublisher>>,
) {
loop {
let maybe_metrics = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_metrics = socket.next() => {
maybe_metrics
}
};
let b = match maybe_metrics {
Some(Ok(b)) => b[0].deref().to_vec(),
Some(Err(err)) => {
tracing::error!("Error reading from vllm metrics socket: {err}");
break;
}
None => {
tracing::debug!("vllm metrics socket closed");
break;
}
};
// Try to deserialize directly into ForwardPassMetrics using Python's pickle module
let metrics_result = Python::with_gil(|py| -> Result<ForwardPassMetrics, String> {
let pickle = py
.import("pickle")
.map_err(|e| format!("Failed to import pickle: {}", e))?;
let loads = pickle
.getattr("loads")
.map_err(|e| format!("Failed to get loads function: {}", e))?;
let bytes = PyBytes::new(py, &b);
let result = loads
.call1((bytes,))
.map_err(|e| format!("Failed to call pickle.loads: {}", e))?;
// Try to extract the attributes from the Python object
let extract_field = |field: &str| -> Result<u64, String> {
result
.getattr(field)
.map_err(|e| format!("Field '{}' not found: {}", field, e))?
.extract::<u64>()
.map_err(|e| format!("Failed to extract '{}' as u64: {}", field, e))
};
let extract_float_field = |field: &str| -> Result<f32, String> {
result
.getattr(field)
.map_err(|e| format!("Field '{}' not found: {}", field, e))?
.extract::<f32>()
.map_err(|e| format!("Failed to extract '{}' as f32: {}", field, e))
};
// Give default values for any fields not found
let request_active_slots = extract_field("request_active_slots").unwrap_or(0);
let request_total_slots = extract_field("request_total_slots").unwrap_or(0);
let kv_active_blocks = extract_field("kv_active_blocks").unwrap_or(0);
let kv_total_blocks = extract_field("kv_total_blocks").unwrap_or(0);
let num_requests_waiting = extract_field("num_requests_waiting").unwrap_or(0);
let gpu_cache_usage_perc = extract_float_field("gpu_cache_usage_perc").unwrap_or(0.0);
let gpu_prefix_cache_hit_rate =
extract_float_field("gpu_prefix_cache_hit_rate").unwrap_or(0.0);
Ok(ForwardPassMetrics {
request_active_slots,
request_total_slots,
kv_active_blocks,
kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
})
});
match metrics_result {
Ok(metrics) => {
if let Some(metrics_publisher) = publisher.as_ref() {
if let Err(err) = metrics_publisher.publish(metrics.into()) {
tracing::error!(%err, "Failed publishing KV metrics");
}
}
}
Err(err) => {
tracing::error!("Error deserializing vllm metrics with Python pickle: {err}");
}
}
}
}
fn from_vllm(output: CompletionOutput, previous_total_toks: usize) -> LLMEngineOutput {
let finish_reason = match output.finish_reason.as_deref() {
Some("stop") => Some(FinishReason::Stop),
Some("abort") => Some(FinishReason::Cancelled),
Some("length") => Some(FinishReason::Length),
Some(unknown) => {
tracing::info!("Unknown vllm stop reason '{unknown}'. Please add to vllm.rs");
Some(FinishReason::Stop)
}
None => None,
};
LLMEngineOutput {
// todo - propagate mdcsum
token_ids: output.token_ids[previous_total_toks..].into(),
tokens: None,
text: None,
//text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs: output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs
finish_reason,
}
}
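// Example of the delta logic above: vllm reports cumulative token_ids, so if a
// response carries [5, 9, 12] and previous_total_toks is 2, only [12] is new,
// and token_ids[2..] is what gets forwarded downstream.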
async fn input_loop(
cancel_token: CancellationToken,
py_imports: Arc<Imports>,
mut input_socket: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
mut rx: tokio::sync::mpsc::Receiver<WorkRequest>,
) {
loop {
let work_request = tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("VllmWorker.input_loop exit");
break;
}
req = rx.recv() => {
match req {
Some(req) => req,
None => {
tracing::trace!("VllmWorker input_loop socket closed");
break;
}
}
}
};
let request_id = work_request.request_id;
let token_ids = work_request.request.token_ids.clone();
let temperature: f64 = work_request
.request
.sampling_options
.temperature
.unwrap_or(0.0)
.into();
// Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let mut sp_kwargs = vec![("temperature", py_temp)];
if let Some(max_tokens) = work_request.request.stop_conditions.max_tokens {
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
// vllm defaults this to 16
sp_kwargs.push(("max_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = py_imports
.sample_params_type
.call(py, (), Some(&sp_kwargs))
.unwrap();
let py_request_id: PyObject = PyString::new(py, &request_id).into();
(py_request_id, sampling_params)
});
let pickled_req: Vec<u8> = Python::with_gil(|py| {
let token_prompt_kwargs = [("prompt_token_ids", token_ids.clone())]
.into_py_dict(py)
.unwrap();
let prompt_obj = py_imports
.tokens_prompt_type
.call(py, (), Some(&token_prompt_kwargs))
.unwrap();
let rpc_kwargs = [
("prompt", prompt_obj),
("params", sampling_params.clone()),
("request_id", py_request_id.clone()),
]
.into_py_dict(py)
.unwrap();
let req = py_imports.rpc_type.call(py, (), Some(&rpc_kwargs)).unwrap();
let pickle_dumps = py_imports.pickle_module.getattr(py, "dumps").unwrap();
pickle_dumps.call1(py, (req,)).unwrap().extract(py).unwrap()
});
let new_active_request = ActiveRequest {
tx: work_request.response_channel,
num_output_tokens_so_far: 0,
};
active_requests
.lock()
.await
.insert(request_id, new_active_request);
if let Err(err) = input_socket.send(vec![pickled_req].into()).await {
tracing::error!("Error sending new request to vllm over zmq: {err}");
}
}
}
/// Read from vllm's output zmq socket, find which request it is for and forward over that channel.
async fn output_loop(
cancel_token: CancellationToken,
py_imports: Arc<Imports>,
mut output_socket: async_zmq::Pull,
active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
) {
loop {
let mut bb = tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("VllmWorker.output_loop exit");
break;
}
from_vllm = output_socket.next() => {
match from_vllm {
Some(Ok(b)) => b,
Some(Err(err)) => {
tracing::error!("Error reading from vllm zmq output: {err}");
continue; // hope lives eternal
}
None => {
tracing::debug!("zmq output socket closed");
break;
}
}
}
};
let frame = bb.remove(0);
let mut reqs_out: Vec<RequestOutput> = Python::with_gil(|py| {
let pickle_loads = py_imports.pickle_module.getattr(py, "loads").unwrap();
let frame_bytes = PyBytes::new(py, &frame);
pickle_loads
.call1(py, (frame_bytes,))
.unwrap()
.extract(py)
.unwrap()
});
if reqs_out.is_empty() {
tracing::debug!("Received message from vllm with no content");
continue;
}
let req_out = reqs_out.remove(0);
if req_out.finished {
// The last token is the eos_token, don't forward it
// TODO: Look at req_out.finish_reason (Option<String>) and set out correctly.
let out = Annotated::from_data(LLMEngineOutput::stop());
let maybe_active = active_requests.lock().await.remove(&req_out.request_id);
match maybe_active {
Some(active) => {
let _ = active.tx.send(out).await;
}
None => {
tracing::warn!(
req_out.request_id,
"Missing active request to notify of stop"
);
}
}
continue;
}
for vllm_output in req_out.outputs.into_iter() {
let next_total_toks = vllm_output.token_ids.len();
match active_requests.lock().await.get_mut(&req_out.request_id) {
Some(active) => {
let out = from_vllm(vllm_output, active.num_output_tokens_so_far);
active.num_output_tokens_so_far = next_total_toks;
let _ = active.tx.send(Annotated::from_data(out)).await;
}
None => {
tracing::warn!(req_out.request_id, "Missing active request");
}
}
}
}
}
impl VllmWorker {
/// Send a request to vllm
pub async fn enqueue_request(&self, r: WorkRequest) -> Result<(), SendError<WorkRequest>> {
self.tx.send(r).await
}
/// Get the vllm sub-process handle, so we can await it and prevent it from going zombie.
pub fn take_vllm_handle(&mut self) -> JoinHandle<()> {
self.vllm.take().unwrap()
}
}
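// Hypothetical usage sketch (assumes a prepared PreprocessedRequest `req`, a
// `model_path`, and that MultiNodeConfig implements Default):
//   let mut worker = start(cancel_token, "sock123", model_path,
//       MultiNodeConfig::default(), 1, None, None).await?;
//   let (tx, mut rx) = tokio::sync::mpsc::channel(8);
//   worker.enqueue_request(WorkRequest {
//       request: req, request_id: "r1".to_string(), response_channel: tx,
//   }).await?;
//   while let Some(out) = rx.recv().await { /* stream each LLMEngineOutput */ }
//   let _ = worker.take_vllm_handle().await; // reap the sub-process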
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "dynamo-engine-vllm0_8"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[dependencies]
dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true }
anyhow = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tracing = { workspace = true }
async-openai = "0.27.2"
pyo3 = { version = "0.23.3", default-features = false, features = [
"macros",
"experimental-async",
"experimental-inspect",
"py-clone",
] }
pyo3-async-runtimes = { version = "0.23.0", default-features = false, features = [
"attributes",
"testing",
"tokio-runtime",
"unstable-streams",
] }
pythonize = { version = "0.23" }
regex = "1"
serde-pickle = "1.2.0"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ffi::CString;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use std::{path::Path, sync::Arc};
use async_stream::stream;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::protocols::common::FinishReason;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{async_trait, Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{CancellationToken, Result};
use pyo3_async_runtimes::TaskLocals;
use serde::Deserialize;
use tokio::sync::oneshot::Sender;
use dynamo_llm::backend::ExecutionContext;
use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyString, PyTuple};
use pyo3::PyObject;
use pyo3::Python;
use pythonize::pythonize;
use tokio_stream::StreamExt;
// The minor revision version of vllm that this engine supports
const VLLM_VERSION: &str = "0.8";
const PY_MAIN_MOD: &str = include_str!("vllm_inc.py");
const VLLM_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(3);
pub async fn make_engine(
cancel_token: CancellationToken,
// Full path to the model, either a GGUF file or an HF repo dir
model_path: &Path,
// Multi node settings
node_conf: MultiNodeConfig,
// How many GPUs to use
tensor_parallel_size: u32,
// Path to extra engine args file
extra_engine_args: Option<PathBuf>,
) -> pipeline_error::Result<ExecutionContext> {
let engine = VllmEngine::new(
cancel_token,
model_path,
node_conf,
tensor_parallel_size,
extra_engine_args,
)
.await?;
let engine: ExecutionContext = Arc::new(engine);
Ok(engine)
}
struct VllmEngine {
cancel_token: CancellationToken,
// How we send requests to Python / vllm
request_queue: Arc<PyObject>,
// asyncio event loop to run all our python futures. vllm is async.
event_loop: Arc<PyObject>,
// The python module from vllm_inc.py
py_main_mod: Arc<PyObject>,
// vllm.SamplingParams
sampling_params: PyObject,
}
impl VllmEngine {
pub async fn new(
cancel_token: CancellationToken,
model_path: &Path,
node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
) -> anyhow::Result<Self> {
pyo3::prepare_freethreaded_python();
// Safety: CString::new will only return an error if the string contains a null byte.
let py_main_mod: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
PyModule::from_code(
py,
&CString::new(PY_MAIN_MOD).expect("vllm_inc.py contains a null byte!"),
&CString::new("_synthetic/dynamo_engine_vllm.py").unwrap(),
&CString::new("dynamo_engine_vllm").unwrap(),
)
.map(|p| p.into())
})?;
let py_main_mod = Arc::new(py_main_mod);
let sampling_params = sampling_params_type();
let (request_queue_rs, request_queue_py) = make_python_queues(64)?;
let (ready_event_rs, ready_event_py) = make_python_event()?;
let (tx, rx) = tokio::sync::oneshot::channel();
let model_path_buf = PathBuf::from(model_path);
let cancel_token_worker = cancel_token.clone();
let py_main_mod_worker = py_main_mod.clone();
tokio::task::spawn(async move {
if let Err(err) = run_vllm_worker(
cancel_token_worker,
tx,
py_main_mod_worker,
request_queue_py,
ready_event_py,
&model_path_buf,
node_conf,
tensor_parallel_size,
extra_engine_args,
)
.await
{
tracing::error!(%err, "run_vllm_worker error");
}
});
let event_loop = tokio::select! {
ev = rx => ev,
_ = cancel_token.cancelled() => {
anyhow::bail!("VllmEngine create cancelled");
}
};
let event_loop = event_loop?;
// Wait for vllm to start accepting requests
tokio::select! {
_ = wait_for_vllm(event_loop.clone(), ready_event_rs) => {
tracing::trace!("vllm worker is ready");
}
_ = cancel_token.cancelled() => {
anyhow::bail!("VllmEngine cancelled waiting for vllm to start");
}
};
let engine = VllmEngine {
cancel_token,
request_queue: Arc::new(request_queue_rs),
event_loop,
py_main_mod,
sampling_params,
};
Ok(engine)
}
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for VllmEngine
{
async fn generate(
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let (request, context) = request.into_parts();
let ctx = context.context();
let request_id = ctx.id().to_string();
let temperature: f64 = request.sampling_options.temperature.unwrap_or(0.0).into();
// Send request
let (response_queue_1, response_queue_2) = make_python_queues(16)?;
let queue_fut = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let mut sp_kwargs = vec![("temperature", py_temp)];
if let Some(max_tokens) = request.stop_conditions.max_tokens {
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
// vllm defaults this to 16
sp_kwargs.push(("max_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = self.sampling_params.call(py, (), Some(&sp_kwargs)).unwrap();
let py_request = pythonize(py, &request)?;
let args: Vec<PyObject> = vec![
PyString::new(py, &request_id).into(),
py_request.into(),
sampling_params,
response_queue_1,
];
let put_arg = PyTuple::new(py, args)?;
let locals = TaskLocals::new(self.event_loop.bind(py).clone());
pyo3_async_runtimes::into_future_with_locals(
&locals.clone_ref(py),
self.request_queue
.bind(py)
.call_method1("put", (put_arg,))?,
)
})?;
queue_fut.await?;
// Read response
let from_vllm = Python::with_gil(|py| {
let locals = TaskLocals::new(self.event_loop.bind(py).clone());
pyo3_async_runtimes::tokio::into_stream_with_locals_v1(
locals,
self.py_main_mod
.bind(py)
.call_method1("run_response", (response_queue_2.bind(py),))?,
)
})?;
let mut from_vllm = Box::pin(from_vllm);
let cancel_token = self.cancel_token.clone();
let output = stream! {
let mut num_output_tokens_so_far = 0;
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!(request_id, "VllmEngine.generate stopped by cancel token");
break;
}
item = from_vllm.next() => {
match item {
None => {
yield Annotated::from_data(LLMEngineOutput::stop());
break;
},
Some(item) => {
match vllm_to_dynamo(item).await {
Ok(Some(mut response)) => {
// The response includes all the tokens.
// We only want the delta.
if response.token_ids.is_empty() {
yield Annotated::from_data(response);
break;
} else {
let next_total_toks = response.token_ids.len();
drop(response.token_ids.drain(0..num_output_tokens_so_far));
num_output_tokens_so_far = next_total_toks;
yield Annotated::from_data(response);
}
}
Ok(None) => {
yield Annotated::from_data(LLMEngineOutput::stop());
break;
},
Err(err) => {
tracing::error!(request_id, %err, "vllm_to_dynamo error");
break;
}
}
}
}
}
} // tokio::select!
} // loop
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
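// Example of the delta logic in `generate`: vllm responses carry cumulative
// token_ids, so for [5, 9] followed by [5, 9, 12], num_output_tokens_so_far is
// 2 when the second response arrives and drain(0..2) leaves just the new token [12].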
// TODO: this panics if vllm is missing or is the wrong version, since we cannot
// run without it. Should we return an error instead?
fn sampling_params_type() -> PyObject {
Python::with_gil(|py| {
let vllm_module: PyObject = match py.import("vllm") {
Ok(m) => m.into(),
Err(err) => {
panic!("Failed to import python 'vllm' module. Are we running in the correct venv? {err}");
}
};
let version = vllm_module
.getattr(py, "__version__")
.expect("vllm missing __version__ field")
.extract::<String>(py)
.expect("vllm.__version__ is not a string");
if !version.starts_with(VLLM_VERSION) {
panic!("Expected vllm version {VLLM_VERSION}, found {version}");
}
let sample_params_type: PyObject = vllm_module
.getattr(py, "SamplingParams")
.expect("vllm module missing SamplingParams type.");
sample_params_type
})
}
/// Create a Python asyncio.Queue. Return two handles to the same queue, one to
/// keep on the Rust side and one to hand to Python.
fn make_python_queues(max_size: usize) -> anyhow::Result<(PyObject, PyObject)> {
Python::with_gil(|py| -> Result<(PyObject, PyObject), String> {
let module: PyObject = match py.import("asyncio") {
Ok(m) => m.into(),
Err(err) => {
panic!("Failed to import python 'asyncio' module. Is Python installed? {err}");
}
};
let kwargs = PyDict::new(py);
kwargs
.set_item("maxsize", max_size)
.map_err(|err| format!("Failed setting maxsize in dict to {max_size}: {err}"))?;
let q = module
.call_method(py, "Queue", (), Some(&kwargs))
.map_err(|e| format!("Failed to call asyncio.Queue: {}", e))?;
Ok((q.clone(), q))
})
.map_err(|err| anyhow::anyhow!("{err}"))
}
fn make_python_event() -> anyhow::Result<(PyObject, PyObject)> {
Python::with_gil(|py| -> Result<(PyObject, PyObject), String> {
let module: PyObject = match py.import("asyncio") {
Ok(m) => m.into(),
Err(err) => {
panic!("Failed to import python 'asyncio' module. Is Python installed? {err}");
}
};
let ev = module
.call_method0(py, "Event")
.map_err(|e| format!("Failed to call asyncio.Event: {}", e))?;
Ok((ev.clone(), ev))
})
.map_err(|err| anyhow::anyhow!("{err}"))
}
/// Start asyncio event loop and block on it forever
#[allow(clippy::too_many_arguments)]
async fn run_vllm_worker(
cancel_token: CancellationToken,
tx: Sender<Arc<PyObject>>,
py_main_mod: Arc<PyObject>,
request_queue: PyObject, // asyncio.Queue
ready_event: PyObject, // asyncio.Event
model_path: &Path,
node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
extra_engine_args: Option<PathBuf>,
) -> anyhow::Result<()> {
let model_path_str = model_path.display().to_string();
let extra_engine_args_str = &extra_engine_args
.map(|p| p.display().to_string())
.unwrap_or_default();
let event_loop: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
let aio: PyObject = py.import("asyncio").map(|p| p.into())?;
aio.call_method0(py, "new_event_loop")
})?;
let event_loop = Arc::new(event_loop);
let _ = tx.send(event_loop.clone());
let event_loop_forever = event_loop.clone();
tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let _ = event_loop_forever.call_method0(py, "run_forever");
});
});
let vllm_fut = Python::with_gil(|py| {
// These go directly to vllm's AsyncEngineArgs
let kwargs: Vec<(&str, PyObject)> = vec![
("model", PyString::new(py, &model_path_str).into()),
("task", PyString::new(py, "generate").into()),
(
"skip_tokenizer_init",
// Safety: true always converts to python object
true.into_pyobject(py).unwrap().to_owned().into(),
),
(
"tensor_parallel_size",
// Safety: A u32 should always convert safely
tensor_parallel_size.into_pyobject(py).unwrap().into(),
),
(
"pipeline_parallel_size",
// Safety: A u32 should always convert safely
node_conf.num_nodes.into_pyobject(py).unwrap().into(),
),
(
"enable_prefix_caching",
// Safety: true always converts to python object
true.into_pyobject(py).unwrap().to_owned().into(),
),
];
let kwargs = kwargs.into_py_dict(py)?;
let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::into_future_with_locals(
&locals.clone_ref(py),
py_main_mod
.call_method(
py,
"main",
(
request_queue.bind(py),
ready_event.bind(py),
extra_engine_args_str,
),
Some(&kwargs),
)?
.into_bound(py),
)
})?;
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("VllmEngine worker stopped by cancel token");
// Stop vllm
let vllm_stop_fut = Python::with_gil(|py| {
let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::into_future_with_locals(
&locals.clone_ref(py),
request_queue.call_method1(py, "put", (py.None(),))?.into_bound(py)
)
})?;
tokio::select! {
_ = vllm_stop_fut => {}
_ = tokio::time::sleep(VLLM_SHUTDOWN_TIMEOUT) => {
tracing::warn!("Timeout waiting for vllm to shut down. Process may still be running");
}
};
}
_ = vllm_fut => {
tracing::warn!("VllmEngine worker unexpected worker task completed");
}
}
Ok(())
}
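// The bridging pattern used throughout run_vllm_worker, in miniature: a
// dedicated asyncio event loop runs forever on a blocking thread, and each
// Python awaitable is pinned to that loop via TaskLocals before being awaited
// from Rust:
//   let locals = TaskLocals::new(event_loop.bind(py).clone());
//   let fut = pyo3_async_runtimes::into_future_with_locals(&locals, awaitable)?;
//   fut.await?;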
#[derive(Debug, thiserror::Error)]
enum ResponseProcessingError {
#[error("python exception: {0}")]
PythonException(String),
}
#[derive(Debug, Clone, Deserialize, FromPyObject)]
pub struct CompletionOutput {
pub index: usize,
pub text: String,
pub token_ids: Vec<u32>,
pub cumulative_logprob: Option<f64>,
pub logprobs: Option<Vec<f64>>,
pub finish_reason: Option<String>,
pub stop_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, FromPyObject)]
pub struct RequestMetrics {
pub arrival_time: f64,
pub last_token_time: f64,
pub first_scheduled_time: f64,
pub first_token_time: f64,
pub time_in_queue: f64,
pub finished_time: Option<f64>,
pub scheduler_time: f64,
pub model_forward_time: Option<f64>,
pub model_execute_time: Option<f64>,
pub spec_token_acceptance_counts: Vec<u32>,
}
// Matches vllm Python type:
// RequestOutput(request_id=b87cf9dd-f66f-422f-ada9-b3e08c642c03, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='', token_ids=(785, 6722), cumulative_logprob=None, logprobs=None, finish_reason=None, stop_reason=None)], finished=False, metrics=RequestMetrics(arrival_time=1744820354.1364133, last_token_time=1744820354.206939, first_scheduled_time=1744820354.137031, first_token_time=1744820354.180786, time_in_queue=0.0006177425384521484, finished_time=None, scheduler_time=0.00044727100248564966, model_forward_time=None, model_execute_time=None, spec_token_acceptance_counts=[0]), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})
#[derive(Debug, Clone, Deserialize, FromPyObject)]
pub struct RequestOutput {
pub request_id: String, // this is a uuid
pub prompt: Option<String>,
pub prompt_token_ids: Vec<u32>,
pub encoder_prompt: Option<String>,
pub encoder_prompt_token_ids: Option<Vec<u32>>,
pub prompt_logprobs: Option<Vec<f32>>,
pub outputs: Vec<CompletionOutput>,
pub finished: bool,
pub metrics: RequestMetrics,
//pub lora_request: Option<serde_json::Value>,
pub num_cached_tokens: usize,
//pub multi_modal_placeholders: HashMap<String, serde_json::Value>,
}
impl From<RequestOutput> for LLMEngineOutput {
fn from(mut req: RequestOutput) -> LLMEngineOutput {
if req.outputs.is_empty() {
// TODO should this be an error?
return LLMEngineOutput::stop();
}
let out = req.outputs.remove(0);
let finish_reason = out
.finish_reason
.map(|fr| match FinishReason::from_str(&fr) {
Ok(fr) => fr,
Err(err) => {
let s = format!("Unsupported finish reason from vllm: {fr}: {err}");
tracing::error!("{s}");
FinishReason::Error(s)
}
});
LLMEngineOutput {
token_ids: out.token_ids,
tokens: None,
text: None,
cum_log_probs: out.cumulative_logprob,
log_probs: out.logprobs,
finish_reason,
}
}
}
// Convert the vllm type to the dynamo type
async fn vllm_to_dynamo(
item: Result<Py<PyAny>, PyErr>,
) -> Result<Option<LLMEngineOutput>, ResponseProcessingError> {
// Handle errors first
let item = item.map_err(|e| {
println!();
Python::with_gil(|py| e.display(py));
ResponseProcessingError::PythonException(e.to_string())
})?;
// None is how Python tells us the request is complete
if Python::with_gil(|py| item.is_none(py)) {
return Ok(None);
}
Python::with_gil(|py| match item.extract::<RequestOutput>(py) {
Ok(response) => Ok(Some(response.into())),
Err(err) => {
tracing::trace!(%err, "Err extract python into RequestOutput. Usually means end of response.");
Ok(None)
}
})
}
async fn wait_for_vllm(
event_loop: Arc<PyObject>,
ready_event_rs: PyObject,
) -> anyhow::Result<PyObject> {
let maybe_py_fut = Python::with_gil(|py| -> PyResult<PyObject> {
ready_event_rs
.bind(py)
.call_method0("wait")
.map(|p| p.into())
});
let py_fut = match maybe_py_fut {
Ok(fut) => fut,
Err(err) => {
anyhow::bail!("Failed calling python event.wait() waiting for vllm to start: {err}");
}
};
let rs_fut = Python::with_gil(|py| {
let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::into_future_with_locals(&locals.clone_ref(py), py_fut.bind(py).clone())
})?;
rs_fut.await.map_err(|err| err.into())
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# This file is included as a string in lib.rs. Most work should be done in the Rust caller.
#
import json
import logging
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs import TokensPrompt
# TODO this should match DYN_LOG level
logging.basicConfig(level=logging.INFO)
async def main(request_queue, ready_event, extra_engine_args, **kwargs):
arg_map = kwargs
if extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
engine_args = AsyncEngineArgs(**arg_map)
# Main loop
try:
async with build_async_engine_client_from_engine_args(
engine_args
) as engine_client:
ready_event.set()
while True:
req = await request_queue.get()
if req is None: # Stop sentinel
break
(request_id, request, sampling_params, response_queue) = req
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
gen = engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
await response_queue.put(res)
await response_queue.put(None)
request_queue.task_done()
except Exception as e:
logging.error(f"vllm init failed: {e}")
finally:
logging.debug("vllm worker stopped")
async def run_response(response_queue):
try:
while True:
item = await response_queue.get()
yield item
response_queue.task_done()
if item is None:
return
except Exception as e:
logging.error(f"failed reading response from vllm: {e}")