Unverified Commit 19a77ae7 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(dynamo-run): Remove out=sglang|vllm|trtllm (#1920)

parent 3c500ae7
......@@ -64,46 +64,6 @@ pub struct Flags {
#[arg(long)]
pub model_config: Option<PathBuf>,
/// sglang, vllm
///
/// How many GPUs to use at once, total across all nodes.
/// This must divide by num_nodes, and each node must use the same number of GPUs.
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub tensor_parallel_size: u32,
/// sglang only
/// vllm uses CUDA_VISIBLE_DEVICES env var
///
/// Use GPUs from this ID upwards.
/// If your machine has four GPUs but the first two (0 and 1) are in use,
/// pass --base-gpu-id 2 to use the third GPU (and up, if tensor_parallel_size > 1)
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..256))]
pub base_gpu_id: u32,
/// vllm and sglang only
///
/// How many nodes/hosts to use
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub num_nodes: u32,
/// vllm and sglang only
///
/// This nodes' unique ID, running from 0 to num_nodes.
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..255))]
pub node_rank: u32,
/// For multi-node / pipeline parallel this is the <host>:<port> of the first node.
///
/// - vllm: The address/port of the Ray head node.
///
/// - sglang: The Torch Distributed init method address, in format <host>:<port>.
/// It becomes "tcp://<host>:<port>" when given to torch.distributed.init_process_group.
/// This expects to use the nccl backend (transparently to us here).
/// All nodes must use the same address here, which is node_rank == 0's address.
///
#[arg(long)]
pub leader_addr: Option<String>,
/// If using `out=dyn` with multiple instances, this says how to route the requests.
///
/// Mostly interesting for KV-aware routing.
......@@ -199,22 +159,6 @@ impl Flags {
}
#[cfg(feature = "mistralrs")]
Output::MistralRs => {}
Output::SgLang => {
if !local_model.path().is_dir() {
// TODO GGUF support for sglang: https://github.com/ai-dynamo/dynamo/issues/572
anyhow::bail!("`--model-path should point at a HuggingFace repo checkout");
}
}
Output::Vllm => {
if self.base_gpu_id != 0 {
anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
}
Output::Trtllm => {
if self.base_gpu_id != 0 {
anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
if !local_model.path().is_file() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
use std::{future::Future, pin::Pin};
use anyhow::Context as _;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig;
......@@ -17,9 +14,6 @@ pub use flags::Flags;
mod opt;
pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::Output;
mod subprocess;
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);
pub async fn run(
runtime: Runtime,
......@@ -48,6 +42,7 @@ pub async fn run(
.request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit);
// TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
let mut rt = Either::Left(runtime.clone());
......@@ -71,7 +66,7 @@ pub async fn run(
flags.validate(&local_model, &out_opt)?;
// Make an engine from the local_model, flags and output.
let (engine_config, extra) = engine_for(
let engine_config = engine_for(
runtime.primary_token(),
out_opt,
flags.clone(),
......@@ -85,17 +80,9 @@ pub async fn run(
//
dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?;
// Allow engines to ask main thread to wait on an extra future.
// We use this to stop the vllm and sglang sub-process
if let Some(extra) = extra {
extra.await;
}
Ok(())
}
type ExtraFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
......@@ -104,71 +91,27 @@ async fn engine_for(
flags: Flags,
local_model: LocalModel,
rt: Either<Runtime, DistributedRuntime>,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
) -> anyhow::Result<EngineConfig> {
match out_opt {
Output::Dynamic => Ok((EngineConfig::Dynamic(Box::new(local_model)), None)),
Output::EchoFull => Ok((
EngineConfig::StaticFull {
Output::Dynamic => Ok(EngineConfig::Dynamic(Box::new(local_model))),
Output::EchoFull => Ok(EngineConfig::StaticFull {
model: Box::new(local_model),
engine: dynamo_llm::engines::make_engine_full(),
},
None,
)),
Output::EchoCore => Ok((
EngineConfig::StaticCore {
}),
Output::EchoCore => Ok(EngineConfig::StaticCore {
engine: dynamo_llm::engines::make_engine_core(),
model: Box::new(local_model),
},
None,
)),
}),
#[cfg(feature = "mistralrs")]
Output::MistralRs => Ok((
EngineConfig::StaticFull {
Output::MistralRs => Ok(EngineConfig::StaticFull {
engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
model: Box::new(local_model),
},
None,
)),
}),
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => Ok((
EngineConfig::StaticCore {
Output::LlamaCpp => Ok(EngineConfig::StaticCore {
engine: dynamo_engine_llamacpp::make_engine(cancel_token, &local_model).await?,
model: Box::new(local_model),
},
None,
)),
// For multi-node config. vllm uses `ray`, see guide
Output::Vllm => shell(subprocess::vllm::PY, cancel_token, local_model, flags, None).await,
// For multi-node config. trtlllm uses `mpi`, see guide
Output::Trtllm => {
shell(
subprocess::trtllm::PY,
cancel_token,
local_model,
flags,
None,
)
.await
}
Output::SgLang => {
let multi_node_config = if flags.num_nodes > 1 {
Some(dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
leader_addr: flags.leader_addr.clone().unwrap_or_default(),
})
} else {
None
};
shell(
subprocess::sglang::PY,
cancel_token,
local_model,
flags,
multi_node_config,
)
.await
}
}),
Output::Mocker => {
let Either::Right(drt) = rt else {
panic!("Mocker requires a distributed runtime to run.");
......@@ -180,76 +123,12 @@ async fn engine_for(
let engine =
dynamo_llm::mocker::engine::make_mocker_engine(drt, endpoint, args).await?;
Ok((
EngineConfig::StaticCore {
Ok(EngineConfig::StaticCore {
engine,
model: Box::new(local_model),
},
None,
))
}
}
}
async fn shell(
py_script: &'static str,
cancel_token: CancellationToken,
local_model: LocalModel,
flags: Flags,
multi_node_config: Option<dynamo_llm::engines::MultiNodeConfig>,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
let (py_script, child) =
match subprocess::start(py_script, &local_model, flags.clone(), multi_node_config).await {
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting engine sub-process: {err}");
}
};
// Sub-process cleanup
let extra: ExtraFuture = Box::pin(async move {
stopper(cancel_token, child, py_script).await;
});
Ok((EngineConfig::Dynamic(Box::new(local_model)), Some(extra)))
}
/// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible.
/// Keeps the TempPath alive until the child is stopped.
async fn stopper(
cancel_token: CancellationToken,
mut child: tokio::process::Child,
py_script: tempfile::TempPath,
) {
cancel_token.cancelled().await;
// Ask subprocess to stop gracefully
if let Some(pid) = child.id() {
unsafe { libc::kill(pid as i32, libc::SIGTERM) };
}
tokio::select! {
exit = child.wait() => {
tracing::trace!("engine sub-process graceful exit");
match exit {
Ok(exit_status) if exit_status.success() => {}
Ok(exit_status) => {
// This is nearly always 15 (SIGTERM)
tracing::trace!("engine sub-process non-0 exit: {exit_status}");
}
Err(err) => {
tracing::warn!("engine sub-process error getting exit status: {err}");
}
}
}
_ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => {
// It didn't stop in time, kill it
child.kill().await.expect("Failed killing engine subprocess");
let _ = child.wait().await;
})
}
}
// This temporary file contains the python script running the engine. It deletes on drop.
// Keep it alive until the engine has stopped.
drop(py_script);
}
/// If the user will benefit from CUDA/Metal/Vulkan, remind them to build with it.
......
......@@ -90,6 +90,11 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
in_opt = Some(val.try_into()?);
}
"out" => {
if val == "sglang" || val == "trtllm" || val == "vllm" {
tracing::error!("To run the {val} engine please use the Python interface, see root README or look in directory `components/backends/`.");
std::process::exit(1);
}
out_opt = Some(val.try_into()?);
}
_ => {
......
......@@ -22,16 +22,6 @@ pub enum Output {
/// Run inference using llama.cpp
LlamaCpp,
/// Run inference using sglang
SgLang,
/// Run inference using trtllm
Trtllm,
// Start vllm in a sub-process connecting via nats
// Sugar for `python vllm_inc.py --endpoint <thing> --model <thing>`
Vllm,
Mocker,
}
......@@ -46,11 +36,7 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "llamacpp")]
"llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
"sglang" => Ok(Output::SgLang),
"trtllm" => Ok(Output::Trtllm),
"vllm" => Ok(Output::Vllm),
"mocker" => Ok(Output::Mocker),
"echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore),
......@@ -79,11 +65,7 @@ impl fmt::Display for Output {
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => "llamacpp",
Output::SgLang => "sglang",
Output::Trtllm => "trtllm",
Output::Vllm => "vllm",
Output::Mocker => "mocker",
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
......@@ -96,7 +78,11 @@ impl fmt::Display for Output {
impl Output {
#[allow(unused_mut)]
pub fn available_engines() -> Vec<String> {
let mut out = vec!["echo_core".to_string(), "echo_full".to_string()];
let mut out = vec![
"echo_core".to_string(),
"echo_full".to_string(),
Output::Mocker.to_string(),
];
#[cfg(feature = "mistralrs")]
{
out.push(Output::MistralRs.to_string());
......@@ -107,11 +93,6 @@ impl Output {
out.push(Output::LlamaCpp.to_string());
}
out.push(Output::SgLang.to_string());
out.push(Output::Trtllm.to_string());
out.push(Output::Vllm.to_string());
out.push(Output::Mocker.to_string());
out
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::borrow::Cow;
use std::io::Write;
use std::process::Stdio;
use std::sync::LazyLock;
use anyhow::Context;
use regex::Regex;
use tokio::io::AsyncBufReadExt;
use crate::flags::RouterMode;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::local_model::LocalModel;
pub mod sglang;
pub mod trtllm;
pub mod vllm;
pub async fn start(
// The Python code to run
py_script: &'static str,
// Model info
local_model: &LocalModel,
// Command line flags for user overrides
flags: super::Flags,
// sglang multi-node config. vllm uses `ray` externally
multi_node_config: Option<MultiNodeConfig>,
) -> anyhow::Result<(tempfile::TempPath, tokio::process::Child)> {
let mut tmp = tempfile::NamedTempFile::new()?;
// Writes on Linux don't block
tmp.write_all(py_script.as_bytes())?;
let script_path = tmp.into_temp_path();
let card = local_model.card();
let mut args = vec![
script_path.to_string_lossy().to_string(),
"--endpoint".to_string(),
local_model.endpoint_id().as_url(),
"--model-path".to_string(),
local_model.path().to_string_lossy().to_string(),
"--model-name".to_string(),
local_model.display_name().to_string(),
"--tensor-parallel-size".to_string(),
flags.tensor_parallel_size.to_string(),
"--kv-block-size".to_string(),
card.kv_cache_block_size.to_string(),
"--context-length".to_string(),
card.context_length.to_string(),
"--migration-limit".to_string(),
card.migration_limit.to_string(),
];
// TRTLLM only
// The worker node will only publish events and metrics if the router mode is KV
if flags.router_mode == RouterMode::KV {
args.push("--publish-events-and-metrics".to_string());
}
// sglang only
// vllm uses CUDA_VISIBLE_DEVICES
if flags.base_gpu_id != 0 {
args.push("--base-gpu-id".to_string());
args.push(flags.base_gpu_id.to_string());
}
// sglang only
if let Some(multi_node_config) = multi_node_config {
args.push("--nnodes".to_string());
args.push(multi_node_config.num_nodes.to_string());
args.push("--node-rank".to_string());
args.push(multi_node_config.node_rank.to_string());
args.push("--dist-init-addr".to_string());
args.push(multi_node_config.leader_addr);
}
if let Some(extra_engine_args) = flags.extra_engine_args {
args.push("--extra-engine-args".to_string());
args.push(extra_engine_args.to_string_lossy().to_string());
}
let mut cmd = tokio::process::Command::new("python3");
cmd.kill_on_drop(false)
.args(args)
.stdout(Stdio::piped())
.stderr(Stdio::piped());
let mut child = cmd
.spawn()
.with_context(|| format!("Failed running: '{}'", pretty_cmd(&cmd)))?;
// Safety: We set stdout/stderr a few lines above
let stdout = tokio::io::BufReader::new(child.stdout.take().unwrap());
let stderr = tokio::io::BufReader::new(child.stderr.take().unwrap());
tokio::spawn(async move {
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("{}", strip_log_prefix(&line));
}
});
tokio::spawn(async move {
let mut lines = stderr.lines();
while let Ok(Some(line)) = lines.next_line().await {
// FIXME: always logging INFO/DEBUG will hide real errors, but
// some libraries log non-errors to stderr, which confuses users
// when we log those as ERROR. Using WARN as a middle ground for
// now, but we can probably be smarter here.
tracing::warn!("{}", strip_log_prefix(&line));
}
});
// We must keep temp path alive, it deletes on drop
Ok((script_path, child))
}
pub fn pretty_cmd(c: &tokio::process::Command) -> String {
format!(
"{} {}",
c.as_std().get_program().to_string_lossy(),
c.as_std()
.get_args()
.map(|x| x.to_string_lossy())
.collect::<Vec<std::borrow::Cow<'_, str>>>()
.join(" ")
)
}
// Thanks Gemini
static LOG_PREFIX_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(
r"^(?:(?:[A-Z]+ \d{2}-\d{2} \d{2}:\d{2}:\d{2})|(?:\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\])) (.*)$"
).unwrap()
// ^ Start of the line.
// (?: Non-capturing group for the two prefix alternatives.
// (?: Non-capturing group for the first prefix type.
// [A-Z]+ One or more uppercase letters (log level).
// (single space) A literal space.
// \d{2}-\d{2} Date like MM-DD.
// (single space) A literal space.
// \d{2}:\d{2}:\d{2} Time like HH:MM:SS.
// ) End of first prefix type group.
// | OR
// (?: Non-capturing group for the second prefix type.
// \[ A literal opening square bracket.
// \d{4}-\d{2}-\d{2} Date like YYYY-MM-DD.
// (single space) A literal space.
// \d{2}:\d{2}:\d{2} Time like HH:MM:SS.
// \] A literal closing square bracket.
// ) End of second prefix type group.
// ) End of the alternatives group.
// (single space) A literal space. This is the space BEFORE the message.
// (.*) Capture group 1: The rest of the line (the message).
// $ End of the line.
});
/// Strips the log level, date, and time from the start of a log line.
///
/// # Examples
/// let line = "INFO 05-06 09:38:50 [async_llm.py:252] Added request 1";
/// assert_eq!(strip_log_prefix(line), "[async_llm.py:252] Added request 1");
///
/// let line_no_prefix = "This is a normal line.";
/// assert_eq!(strip_log_prefix(line_no_prefix), "This is a normal line.");
fn strip_log_prefix(line: &str) -> Cow<'_, str> {
if let Some(captures) = LOG_PREFIX_RE.captures(line) {
// `captures.get(0)` would be the entire matched prefix + message.
// `captures.get(1)` is the first capture group, which is `(.*)`, the message itself.
if let Some(message_match) = captures.get(1) {
return Cow::Borrowed(message_match.as_str());
}
}
// If the regex doesn't match, or somehow the capture group is not found (shouldn't happen with (.*))
// return the original line.
Cow::Borrowed(line)
}
#[cfg(test)]
mod tests {
use super::strip_log_prefix;
#[test]
fn test_strip_log_prefix() {
let line = "INFO 05-06 09:38:50 [async_llm.py:252] Added request 1";
let expected = "[async_llm.py:252] Added request 1";
assert_eq!(strip_log_prefix(line), expected);
let line = "Just a regular line.";
assert_eq!(strip_log_prefix(line), line);
let line = "INFO this is not a full prefix";
assert_eq!(strip_log_prefix(line), line);
let line = "[2025-05-06 11:58:51] Capture cuda graph bs [1, 2, 4, 8]";
assert_eq!(strip_log_prefix(line), "Capture cuda graph bs [1, 2, 4, 8]");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Source code of the SGLang sub-process
pub const PY: &str = include_str!("sglang_inc.py");
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
import argparse
import asyncio
import json
import logging
import sys
from typing import Optional
import sglang
import uvloop
from sglang.srt.entrypoints.engine import EmbeddingReqInput
from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
base_gpu_id: int
tensor_parallel_size: int
kv_block_size: int
context_length: int
nnodes: int
node_rank: int
dist_init_addr: str
migration_limit: int
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine):
self.engine_client = engine
async def generate(self, request):
sampling_params = {}
if request["sampling_options"]["temperature"] is not None:
sampling_params["temperature"] = request["sampling_options"]["temperature"]
sampling_params = {
# sglang defaults this to 128
"max_new_tokens": request["stop_conditions"]["max_tokens"],
}
# Check if this is a batch request
is_batch = "batch_token_ids" in request and request["batch_token_ids"]
if is_batch:
# Track tokens separately for each batch item
num_output_tokens_so_far = {}
logging.debug("received batch token ids")
gen = await self.engine_client.async_generate(
input_ids=request["batch_token_ids"],
sampling_params=sampling_params,
stream=True,
)
else:
num_output_tokens_so_far = 0
logging.debug("received token ids")
gen = await self.engine_client.async_generate(
input_ids=request["token_ids"],
sampling_params=sampling_params,
stream=True,
)
async for res in gen:
# res is a dict
logging.debug(f"res: {res}")
finish_reason = res["meta_info"]["finish_reason"]
if is_batch:
# Handle batch response - get index from SGLang response
index = res.get("index", 0)
if index not in num_output_tokens_so_far:
num_output_tokens_so_far[index] = 0
if finish_reason:
logging.warning(f"finish_reason: {finish_reason}")
# Final response for this batch item
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
"index": index,
}
else:
# Streaming response for this batch item
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far[index] :]
out = {
"token_ids": new_tokens,
"index": index,
}
num_output_tokens_so_far[index] = next_total_toks
else:
if finish_reason:
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
}
else:
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far:]
out = {
"token_ids": new_tokens,
}
num_output_tokens_so_far = next_total_toks
yield out
async def encode(self, request):
obj = EmbeddingReqInput(input_ids=request["token_ids"])
generator = self.engine_client.tokenizer_manager.generate_request(obj, None)
engine_results = await anext(generator)
tokens = 0
embeddings = []
for result in engine_results:
embeddings.append(result["embedding"])
tokens += result["meta_info"]["prompt_tokens"]
out = {
"embeddings": embeddings,
"prompt_tokens": tokens,
"total_tokens": tokens,
}
yield out
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model_path": config.model_path,
"skip_tokenizer_init": True,
"tp_size": config.tensor_parallel_size,
"base_gpu_id": config.base_gpu_id,
}
if config.kv_block_size:
arg_map["page_size"] = config.kv_block_size
if config.context_length:
arg_map["context_length"] = config.context_length
if config.dist_init_addr != "":
arg_map["trust_remote_code"] = True
arg_map["nnodes"] = config.nnodes
arg_map["dist_init_addr"] = config.dist_init_addr
# In practice this is always 0 because Dynamo only manages the leader
arg_map["node_rank"] = config.node_rank
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# TODO fetch default SamplingParams from generation_config.json
engine_args = ServerArgs(**arg_map)
engine_client = sglang.Engine(server_args=engine_args)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(
model_type,
endpoint,
config.model_path,
config.model_name,
migration_limit=config.migration_limit,
)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
handler = RequestHandler(engine_client)
if engine_args.is_embedding:
await endpoint.serve_endpoint(handler.encode)
else:
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="SGLang server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=0,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--nnodes", type=int, default=1, help="The number of machines SGLang will use"
)
parser.add_argument(
"--node-rank",
type=int,
default=0,
help="Unique number for each node. 0 for the leader.",
)
parser.add_argument(
"--dist-init-addr",
type=str,
default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the SGLang Engine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.base_gpu_id = args.base_gpu_id
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Source code of the TRTLLM sub-process
pub const PY: &str = include_str!("trtllm_inc.py");
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is a sample config for TensorRT-LLM engine.
# The config provides smaller free_gpu_memory_fraction to ensure that the engine
# does not use all the GPU memory and both prefill and decode workers can fit in
# the GPU memory when running in disaggregated mode.
# You might have to tweak this config based on your model size and GPU memory.
backend: pytorch
disable_overlap_scheduler: true
kv_cache_config:
free_gpu_memory_fraction: 0.40
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# TODO:
# - Support disaggregated serving
# - Update examples to use this engine.
#
# `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
#
# Disaggregated serving:
# - Ingress: dynamo run in=http out=dyn
# - Decode Worker: python3 trtllm_inc.py --task=decode --extra-engine-args=trtllm_config/sample.yaml
# - Prefill Worker: python3 trtllm_inc.py --task=prefill --extra-engine-args=trtllm_config/sample.yaml
import argparse
import asyncio
import base64
import copy
import logging
import sys
import warnings
from dataclasses import asdict, dataclass
from typing import Optional
import uvloop
# Import TRTLLM and related modules
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.llm import (
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default endpoint for the remote prefill service.
DEFAULT_PREFILL_ENDPOINT = "dyn://dynamo.prefill.generate"
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging()
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint format: '{endpoint}'. "
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
return tuple(endpoint_parts)
class DisaggregatedParamsCodec:
"""
Codec for encoding and decoding disaggregated params for network transfer.
"""
@staticmethod
def decode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
opaque_state = (
base64.b64decode(disaggregated_params.opaque_state)
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
@staticmethod
def encode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
encoded_opaque_state = (
base64.b64encode(disaggregated_params.opaque_state).decode("utf-8")
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=encoded_opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str] = None
tensor_parallel_size: int
kv_block_size: int
migration_limit: int
extra_engine_args: str
publish_events_and_metrics: bool
disaggregation_mode: str
remote_prefill_endpoint: str
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"migration_limit={self.migration_limit}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"remote_prefill_endpoint={self.remote_prefill_endpoint})"
)
@dataclass
class RequestHandlerConfig:
"""
Configuration for the request handler
"""
component: object
engine: object
default_sampling_params: object
publisher: object
disaggregation_mode: str
remote_prefill_client: object
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, config: RequestHandlerConfig):
self.engine = config.engine
self.component = config.component
self.default_sampling_params = config.default_sampling_params
self.publisher = config.publisher
self.disaggregation_mode = config.disaggregation_mode
self.remote_prefill_client = config.remote_prefill_client
self.first_generation = True
async def remote_prefill(self, request):
"""
Send a prefill request to the remote prefill worker.
Args:
request: The original request to be sent for prefill
Returns:
The response from the remote prefill worker
Raises:
ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request = copy.deepcopy(request)
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request["stop_conditions"]["max_tokens"] = 1
# Set the disaggregated params to context_only for remote prefill
prefill_request["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(
DisaggregatedParams(request_type="context_only")
)
)
if self.remote_prefill_client is None:
raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self.remote_prefill_client.round_robin(
prefill_request
)
]
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
if len(remote_prefill_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
if len(remote_prefill_responses) == 0:
raise ValueError("No response received from remote prefill worker")
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response
async def generate(self, request):
# Check if there is an error in the publisher error queue
publishers_error = (
self.publisher.check_error_queue() if self.publisher else None
)
if publishers_error:
raise publishers_error
inputs = request["token_ids"]
# Decode the disaggregated params from the request
if "disaggregated_params" in request:
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"])
)
else:
disaggregated_params = None
num_output_tokens_so_far = 0
if self.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**remote_prefill_response["disaggregated_params"])
)
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self.engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=(self.disaggregation_mode != "prefill"),
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False
if res.finished and self.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
logging.info(f"Initializing the worker with config: {config}")
remote_prefill_client = None
if config.disaggregation_mode == "decode":
logging.info(
f"Initializing remote prefill client for endpoint: {config.remote_prefill_endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.remote_prefill_endpoint
)
remote_prefill_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path)
arg_map = {
"model": model_path,
"tensor_parallel_size": config.tensor_parallel_size,
"backend": "pytorch",
"skip_tokenizer_init": True,
}
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config = None
if "kv_cache_config" not in arg_map:
kv_cache_config = {}
kv_cache_config["event_buffer_max_size"] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = arg_map["kv_cache_config"]
if not kv_cache_config.event_buffer_max_size:
kv_cache_config.event_buffer_max_size = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
arg_map["kv_cache_config"] = kv_cache_config
# Only pytorch backend is supported for now to publish events and metrics.
if "backend" not in arg_map:
arg_map["backend"] = "pytorch"
elif arg_map["backend"] != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
sys.exit(1)
logging.info(f"TRTLLM engine args: {arg_map}")
engine_args = arg_map
# Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
if config.disaggregation_mode != "prefill":
# Register the model with the endpoint if disaggregation mode is not prefill.
# Prefill worker will get the request directly from the Decode worker and not
# through the ingress.
# FIXME: Enable publishing events and metrics for disaggregated prefill.
# Currently prefill workers are chosen in round-robin fashion.
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
component=component,
engine=engine,
default_sampling_params=default_sampling_params,
publisher=None,
disaggregation_mode=config.disaggregation_mode,
remote_prefill_client=remote_prefill_client,
)
if (
config.publish_events_and_metrics
and config.disaggregation_mode != "prefill"
):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
async with get_tensorrtllm_publisher(
component,
engine,
kv_listener,
int(endpoint.lease_id()),
config.kv_block_size,
) as publisher:
handler_config.publisher = publisher
handler = RequestHandler(handler_config)
await endpoint.serve_endpoint(handler.generate)
else:
handler = RequestHandler(handler_config)
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="TensorRT-LLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
)
parser.add_argument(
"--task",
type=str,
action="append",
choices=["prefill", "decode", "prefill_and_decode"],
default=[],
help="Specifies the task for the engine. Can be specified multiple time for different tasks. Will raise an error if conflicting tasks are specified.",
)
parser.add_argument(
"--remote-prefill-endpoint",
type=str,
default=DEFAULT_PREFILL_ENDPOINT,
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send prefill requests to when running in decode disaggregation mode. Default: {DEFAULT_PREFILL_ENDPOINT}",
)
args = parser.parse_args()
# Validate arguments
if args.context_length is not None:
warnings.warn(
"--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
UserWarning,
)
endpoint = args.endpoint
# disaggregation mode
disaggregation_mode = None
for choice in ["prefill", "decode", "prefill_and_decode"]:
if choice in args.task:
if disaggregation_mode is not None:
raise ValueError(
f"Conflicting tasks specified: {args.task}. Please specify only one task."
)
disaggregation_mode = choice
if disaggregation_mode is None:
disaggregation_mode = "prefill_and_decode"
if disaggregation_mode == "prefill":
if args.remote_prefill_endpoint != DEFAULT_PREFILL_ENDPOINT:
logging.error(
"--remote-prefill-endpoint is not supported when running in prefill disaggregation mode."
)
sys.exit(1)
else:
endpoint = DEFAULT_PREFILL_ENDPOINT
if args.publish_events_and_metrics:
warnings.warn(
"--publish-events-and-metrics is not supported when running in prefill disaggregation mode.",
UserWarning,
)
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode
config.remote_prefill_endpoint = args.remote_prefill_endpoint
return config
if __name__ == "__main__":
uvloop.install()
try:
asyncio.run(worker())
except KeyboardInterrupt:
logging.info("Received SIGINT, shutting down...")
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Source code of the VLLM sub-process
pub const PY: &str = include_str!("vllm_inc.py");
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed - and patched if using kv routing.
# - `libdynamo_llm_capi.so` is in system lib path or it's containing folder is in LD_LIBRARY_PATH
# It builds in target/debug/ by default.
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Optional
import uvloop
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs import TokensPrompt
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk.lib.utils import get_capi_library_path
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
context_length: int
migration_limit: int
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.metrics_publisher = WorkerMetricsPublisher()
def setup_kv_metrics(self):
if not hasattr(self.engine_client, "set_metrics_publisher"):
logging.debug("VLLM version does not support KV metrics")
return
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
# Create the structured metrics objects
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=1024,
num_requests_waiting=0,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request):
# logging.debug(f"Received request: {request}")
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
gen = self.engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
def _check_and_set_env_value(key, expected, allow_override=False):
if not allow_override and key in os.environ and os.environ[key] != expected:
raise ValueError(
f"{key} is set and doesn't equal expected {expected}. Please unset variable before launch."
)
os.environ.setdefault(key, expected)
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model": config.model_path,
"task": "generate",
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
assert config.kv_block_size > 0, "Must use non-negative integer for KV Block Size"
arg_map["block_size"] = config.kv_block_size
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
_check_and_set_env_value("VLLM_WORKER_ID", str(endpoint.lease_id()))
_check_and_set_env_value(
"VLLM_KV_CAPI_PATH", get_capi_library_path(), allow_override=True
)
_check_and_set_env_value("VLLM_KV_NAMESPACE", config.namespace)
_check_and_set_env_value("VLLM_KV_COMPONENT", config.component)
_check_and_set_env_value(
"VLLM_NO_USAGE_STATS", "1", allow_override=True
) # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
default_sampling_params = model_config.get_diff_sampling_param()
engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__()
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
context_length=arg_map.get(
"max_model_len", None
), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"],
migration_limit=config.migration_limit,
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="vLLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the vLLM AsyncLLMEngine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed. V1 is compatible with v0.9.0
# Steps:
# git clone https://github.com/vllm-project/vllm.git
# cd vllm && git checkout v0.9.0
# uv pip uninstall ai-dynamo-vllm
# VLLM_USE_PRECOMPILED=1 uv pip install --editable .
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Optional
import uvloop
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
SpecDecodeStats,
WorkerMetricsPublisher,
WorkerStats,
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
register_llm,
)
from dynamo.runtime import Component, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
context_length: int
migration_limit: int
extra_engine_args: str
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
self.dp_rank = dp_rank
def record(
self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]
):
# request_total_slots and kv_total_blocks are properties of model + gpu
# we should only publish them once, not every metric update
# they should be part of some runtime metadata tied to MDC or put in etcd ?
hit_rate = 0
if scheduler_stats.prefix_cache_stats.queries > 0:
hit_rate = (
scheduler_stats.prefix_cache_stats.hits
/ scheduler_stats.prefix_cache_stats.queries
)
worker_stats = WorkerStats(
request_active_slots=scheduler_stats.num_running_reqs,
request_total_slots=0, # TODO - remove from metrics
num_requests_waiting=scheduler_stats.num_waiting_reqs,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0, # TODO - need to calculate this
kv_total_blocks=0, # TODO - remove from metrics
gpu_cache_usage_perc=scheduler_stats.gpu_cache_usage, # used in current cost function
gpu_prefix_cache_hit_rate=hit_rate,
)
spec_dec_stats = scheduler_stats.spec_decoding_stats
if spec_dec_stats:
spec_dec_stats = SpecDecodeStats(
num_spec_tokens=spec_dec_stats.num_spec_tokens,
num_drafts=spec_dec_stats.num_drafts,
num_draft_tokens=spec_dec_stats.num_draft_tokens,
num_accepted_tokens=spec_dec_stats.num_accepted_tokens,
num_accepted_tokens_per_pos=spec_dec_stats.num_accepted_tokens_per_pos,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats,
kv_stats=kv_stats,
spec_decode_stats=spec_dec_stats,
)
self.inner.publish(metrics)
def log_engine_initialized(self) -> None:
pass
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component) -> None:
self.component = component
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
return DynamoStatLoggerPublisher(self.component, dp_rank)
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
return self.create_stat_logger(dp_rank=dp_rank)
class RequestHandler:
"""
Request handler for the generate and clear_kv_blocks endpoints.
"""
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
async def clear_kv_blocks(self, request=None):
try:
await self.engine_client.reset_prefix_cache()
yield {"status": "success", "message": "KV cache cleared"}
except Exception as e:
yield {"status": "error", "message": str(e)}
async def generate(self, request):
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
gen = self.engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
await register_llm(
ModelType.Backend,
generate_endpoint,
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
arg_map = {
"model": config.model_path,
"task": "generate",
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True, publisher="zmq"
),
}
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
if config.kv_block_size > 0:
arg_map["block_size"] = config.kv_block_size
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
logger.info(f"VLLM config: {arg_map}")
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ[
"VLLM_WORKER_MULTIPROC_METHOD"
] = "spawn" # Ensure our publisher makes it to the new process
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
default_sampling_params = model_config.get_diff_sampling_param()
# Taken from build_async_engine_client_from_engine_args()
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# Explicitly pass our custom stat logger for metrics
engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=[StatLoggerFactory(component)],
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
)
logger.info("VllmWorker has been initialized")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(), kv_block_size=engine_args.block_size
)
_ = ZmqKvEventPublisher(component=component, config=zmq_config)
handler = RequestHandler(component, engine_client, default_sampling_params)
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
def cmd_line_args():
parser = argparse.ArgumentParser(
description="vLLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the vLLM AsyncLLMEngine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -22,7 +22,6 @@
import argparse
import asyncio
import signal
import sys
from pathlib import Path
......@@ -31,9 +30,6 @@ import uvloop
from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input
from dynamo.runtime import DistributedRuntime
subprocess_ref = None # Global process reference for cleanup
subprocess_task = None # Global async task reference for cleanup
def parse_args():
in_mode = "text"
......@@ -90,13 +86,6 @@ def parse_args():
# http_port: Option<u16>
parser.add_argument("--http-port", type=int, help="HTTP port for the engine (u16).")
# TODO: Not yet used here
parser.add_argument(
"--tensor-parallel-size",
type=int,
help="Tensor parallel size for the model (e.g., 4).",
)
# Add the positional model argument.
# It's made optional (nargs='?') because its requirement depends on 'out_mode',
# which is handled in post-parsing validation.
......@@ -131,39 +120,8 @@ def parse_args():
return parsed_args
async def cleanup_subprocess_async():
"""Clean up the sglang/vllm/trtllm subprocess if it exists."""
global subprocess_ref
if subprocess_ref and subprocess_ref.returncode is None:
subprocess_ref.terminate()
try:
await asyncio.wait_for(subprocess_ref.wait(), timeout=2)
except asyncio.TimeoutError:
subprocess_ref.kill()
await subprocess_ref.wait()
# Only cleanup once
subprocess_ref = None
def signal_handler():
"""Handle signals in async context by cleaning up subprocess and exiting."""
asyncio.create_task(cleanup_subprocess_async())
sys.exit(0)
async def run():
global subprocess_ref
global subprocess_task
# Register signal handlers
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, signal_handler) # Ctrl-C
loop.add_signal_handler(signal.SIGTERM, signal_handler) # kill
# If we find cases where subprocess does not stop we may need this. Seem OK so far.
# atexit.register(cleanup_subprocess)
runtime = DistributedRuntime(loop, False)
args = parse_args()
......@@ -174,57 +132,6 @@ async def run():
}
out_mode = args["out_mode"]
# Handle subprocess execution for sglang and vllm
if out_mode in ["sglang", "vllm", "trtllm"]:
# Determine which script to run
script_name = f"{out_mode}_inc.py"
script_path = Path(__file__).parent / script_name
if not script_path.exists():
print(f"Error: Script '{script_path}' not found")
sys.exit(1)
# Build command with all relevant arguments
cmd = [sys.executable, str(script_path)]
# Add arguments if they exist
if args["model_path"]:
cmd.extend(["--model-path", args["model_path"]])
flags = args["flags"]
if flags.model_name:
cmd.extend(["--model-name", flags.model_name])
if flags.context_length:
cmd.extend(["--context-length", str(flags.context_length)])
if flags.kv_cache_block_size:
cmd.extend(["--kv-cache-block-size", str(flags.kv_cache_block_size)])
# Start subprocess in background and stream output
print(f"Starting {out_mode} subprocess: {' '.join(cmd)}")
async def stream_subprocess_output():
global subprocess_ref
subprocess_ref = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
try:
if subprocess_ref.stdout is not None:
async for line in subprocess_ref.stdout:
print(f"Engine: {line.decode().rstrip()}")
await subprocess_ref.wait()
except asyncio.CancelledError:
# Task was cancelled, terminate the subprocess
await cleanup_subprocess_async()
raise
task = asyncio.create_task(stream_subprocess_output())
# Store the task reference for potential cleanup
subprocess_task = task
# Set out_mode to "dyn" because we talk to the subprocess over NATS
out_mode = "dyn"
engine_type = engine_type_map.get(out_mode)
if engine_type is None:
print(f"Unsupported output type: {out_mode}")
......@@ -249,19 +156,7 @@ async def run():
e = EntrypointArgs(engine_type, **entrypoint_kwargs)
engine = await make_engine(runtime, e)
try:
await run_input(runtime, args["in_mode"], engine)
finally:
# Clean up subprocess when main execution finishes
await cleanup_subprocess_async()
# Cancel the subprocess task if it exists
if subprocess_task:
subprocess_task.cancel()
try:
await subprocess_task
except asyncio.CancelledError:
pass
if __name__ == "__main__":
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
import argparse
import asyncio
import json
import logging
import sys
from typing import Optional
import sglang
import uvloop
from sglang.srt.entrypoints.engine import EmbeddingReqInput
from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
base_gpu_id: int
tensor_parallel_size: int
kv_block_size: int
context_length: int
nnodes: int
node_rank: int
dist_init_addr: str
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine):
self.engine_client = engine
async def generate(self, request):
sampling_params = {}
if request["sampling_options"]["temperature"] is not None:
sampling_params["temperature"] = request["sampling_options"]["temperature"]
# sglang defaults this to 128
sampling_params["max_new_tokens"] = request["stop_conditions"]["max_tokens"]
# Check if this is a batch request
is_batch = "batch_token_ids" in request and request["batch_token_ids"]
if is_batch:
# Track tokens separately for each batch item
num_output_tokens_so_far = {}
gen = await self.engine_client.async_generate(
input_ids=request["batch_token_ids"],
sampling_params=sampling_params,
stream=True,
)
else:
num_output_tokens_so_far = 0
gen = await self.engine_client.async_generate(
input_ids=request["token_ids"],
sampling_params=sampling_params,
stream=True,
)
async for res in gen:
# res is a dict
finish_reason = res["meta_info"]["finish_reason"]
if is_batch:
# Handle batch response - get index from SGLang response
index = res.get("index", 0)
if index not in num_output_tokens_so_far:
num_output_tokens_so_far[index] = 0
if finish_reason:
logging.warning(f"finish_reason: {finish_reason}")
# Final response for this batch item
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
"index": index,
}
else:
# Streaming response for this batch item
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far[index] :]
out = {
"token_ids": new_tokens,
"index": index,
}
num_output_tokens_so_far[index] = next_total_toks
else:
if finish_reason:
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
}
else:
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far:]
out = {
"token_ids": new_tokens,
}
num_output_tokens_so_far = next_total_toks
yield out
async def encode(self, request):
obj = EmbeddingReqInput(input_ids=request["token_ids"])
generator = self.engine_client.tokenizer_manager.generate_request(obj, None)
engine_results = await anext(generator)
tokens = 0
embeddings = []
for result in engine_results:
embeddings.append(result["embedding"])
tokens += result["meta_info"]["prompt_tokens"]
out = {
"embeddings": embeddings,
"prompt_tokens": tokens,
"total_tokens": tokens,
}
yield out
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model_path": config.model_path,
"skip_tokenizer_init": True,
"tp_size": config.tensor_parallel_size,
"base_gpu_id": config.base_gpu_id,
}
if config.kv_block_size:
arg_map["page_size"] = config.kv_block_size
if config.context_length:
arg_map["context_length"] = config.context_length
if config.dist_init_addr != "":
arg_map["trust_remote_code"] = True
arg_map["nnodes"] = config.nnodes
arg_map["dist_init_addr"] = config.dist_init_addr
# In practice this is always 0 because Dynamo only manages the leader
arg_map["node_rank"] = config.node_rank
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# TODO fetch default SamplingParams from generation_config.json
engine_args = ServerArgs(**arg_map)
engine_client = sglang.Engine(server_args=engine_args)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(model_type, endpoint, config.model_path, config.model_name)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
handler = RequestHandler(engine_client)
if engine_args.is_embedding:
await endpoint.serve_endpoint(handler.encode)
else:
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="SGLang server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=0,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--nnodes", type=int, default=1, help="The number of machines SGLang will use"
)
parser.add_argument(
"--node-rank",
type=int,
default=0,
help="Unique number for each node. 0 for the leader.",
)
parser.add_argument(
"--dist-init-addr",
type=str,
default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the SGLang Engine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.base_gpu_id = args.base_gpu_id
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# TODO:
# - Support disaggregated serving
# - Update examples to use this engine.
#
# `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
#
# Disaggregated serving:
# - Ingress: dynamo run in=http out=dyn
# - Decode Worker: python3 trtllm_inc.py --task=decode --extra-engine-args=trtllm_config/sample.yaml
# - Prefill Worker: python3 trtllm_inc.py --task=prefill --extra-engine-args=trtllm_config/sample.yaml
import argparse
import asyncio
import base64
import copy
import logging
import sys
import warnings
from dataclasses import asdict, dataclass
from typing import Optional
import uvloop
# Import TRTLLM and related modules
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.llm import (
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default endpoint for the remote prefill service.
DEFAULT_PREFILL_ENDPOINT = "dyn://dynamo.prefill.generate"
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging()
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint format: '{endpoint}'. "
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
return tuple(endpoint_parts)
class DisaggregatedParamsCodec:
"""
Codec for encoding and decoding disaggregated params for network transfer.
"""
@staticmethod
def decode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
opaque_state = (
base64.b64decode(disaggregated_params.opaque_state)
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
@staticmethod
def encode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
encoded_opaque_state = (
base64.b64encode(disaggregated_params.opaque_state).decode("utf-8")
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=encoded_opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str] = None
tensor_parallel_size: int
kv_block_size: int
extra_engine_args: str
publish_events_and_metrics: bool
disaggregation_mode: str
remote_prefill_endpoint: str
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"remote_prefill_endpoint={self.remote_prefill_endpoint})"
)
@dataclass
class RequestHandlerConfig:
"""
Configuration for the request handler
"""
component: object
engine: object
default_sampling_params: object
publisher: object
disaggregation_mode: str
remote_prefill_client: object
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, config: RequestHandlerConfig):
self.engine = config.engine
self.component = config.component
self.default_sampling_params = config.default_sampling_params
self.publisher = config.publisher
self.disaggregation_mode = config.disaggregation_mode
self.remote_prefill_client = config.remote_prefill_client
self.first_generation = True
async def remote_prefill(self, request):
"""
Send a prefill request to the remote prefill worker.
Args:
request: The original request to be sent for prefill
Returns:
The response from the remote prefill worker
Raises:
ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request = copy.deepcopy(request)
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request["stop_conditions"]["max_tokens"] = 1
# Set the disaggregated params to context_only for remote prefill
prefill_request["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(
DisaggregatedParams(request_type="context_only")
)
)
if self.remote_prefill_client is None:
raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self.remote_prefill_client.round_robin(
prefill_request
)
]
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
if len(remote_prefill_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
if len(remote_prefill_responses) == 0:
raise ValueError("No response received from remote prefill worker")
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response
async def generate(self, request):
# Check if there is an error in the publisher error queue
publishers_error = (
self.publisher.check_error_queue() if self.publisher else None
)
if publishers_error:
raise publishers_error
inputs = request["token_ids"]
# Decode the disaggregated params from the request
if "disaggregated_params" in request:
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"])
)
else:
disaggregated_params = None
num_output_tokens_so_far = 0
if self.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**remote_prefill_response["disaggregated_params"])
)
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self.engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=(self.disaggregation_mode != "prefill"),
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False
if res.finished and self.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
logging.info(f"Initializing the worker with config: {config}")
remote_prefill_client = None
if config.disaggregation_mode == "decode":
logging.info(
f"Initializing remote prefill client for endpoint: {config.remote_prefill_endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.remote_prefill_endpoint
)
remote_prefill_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path)
arg_map = {
"model": model_path,
"tensor_parallel_size": config.tensor_parallel_size,
"backend": "pytorch",
"skip_tokenizer_init": True,
}
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config = None
if "kv_cache_config" not in arg_map:
kv_cache_config = {}
kv_cache_config["event_buffer_max_size"] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = arg_map["kv_cache_config"]
if "event_buffer_max_size" not in kv_cache_config:
kv_cache_config[
"event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
arg_map["kv_cache_config"] = kv_cache_config
# Only pytorch backend is supported for now to publish events and metrics.
if "backend" not in arg_map:
arg_map["backend"] = "pytorch"
elif arg_map["backend"] != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
sys.exit(1)
logging.info(f"TRTLLM engine args: {arg_map}")
engine_args = arg_map
# Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
if config.disaggregation_mode != "prefill":
# Register the model with the endpoint if disaggregation mode is not prefill.
# Prefill worker will get the request directly from the Decode worker and not
# through the ingress.
# FIXME: Enable publishing events and metrics for disaggregated prefill.
# Currently prefill workers are chosen in round-robin fashion.
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
component=component,
engine=engine,
default_sampling_params=default_sampling_params,
publisher=None,
disaggregation_mode=config.disaggregation_mode,
remote_prefill_client=remote_prefill_client,
)
if (
config.publish_events_and_metrics
and config.disaggregation_mode != "prefill"
):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
async with get_tensorrtllm_publisher(
component,
engine,
kv_listener,
int(endpoint.lease_id()),
config.kv_block_size,
) as publisher:
handler_config.publisher = publisher
handler = RequestHandler(handler_config)
await endpoint.serve_endpoint(handler.generate)
else:
handler = RequestHandler(handler_config)
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="TensorRT-LLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
)
parser.add_argument(
"--task",
type=str,
action="append",
choices=["prefill", "decode", "prefill_and_decode"],
default=[],
help="Specifies the task for the engine. Can be specified multiple time for different tasks. Will raise an error if conflicting tasks are specified.",
)
parser.add_argument(
"--remote-prefill-endpoint",
type=str,
default=DEFAULT_PREFILL_ENDPOINT,
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send prefill requests to when running in decode disaggregation mode. Default: {DEFAULT_PREFILL_ENDPOINT}",
)
args = parser.parse_args()
# Validate arguments
if args.context_length is not None:
warnings.warn(
"--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
UserWarning,
)
endpoint = args.endpoint
# disaggregation mode
disaggregation_mode = None
for choice in ["prefill", "decode", "prefill_and_decode"]:
if choice in args.task:
if disaggregation_mode is not None:
raise ValueError(
f"Conflicting tasks specified: {args.task}. Please specify only one task."
)
disaggregation_mode = choice
if disaggregation_mode is None:
disaggregation_mode = "prefill_and_decode"
if disaggregation_mode == "prefill":
if args.remote_prefill_endpoint != DEFAULT_PREFILL_ENDPOINT:
logging.error(
"--remote-prefill-endpoint is not supported when running in prefill disaggregation mode."
)
sys.exit(1)
else:
endpoint = DEFAULT_PREFILL_ENDPOINT
if args.publish_events_and_metrics:
warnings.warn(
"--publish-events-and-metrics is not supported when running in prefill disaggregation mode.",
UserWarning,
)
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode
config.remote_prefill_endpoint = args.remote_prefill_endpoint
return config
if __name__ == "__main__":
uvloop.install()
try:
asyncio.run(worker())
except KeyboardInterrupt:
logging.info("Received SIGINT, shutting down...")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed - and patched if using kv routing.
# - `libdynamo_llm_capi.so` is in system lib path or it's containing folder is in LD_LIBRARY_PATH
# It builds in target/debug/ by default.
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Optional
import uvloop
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs import TokensPrompt
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
context_length: int
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.metrics_publisher = WorkerMetricsPublisher()
def setup_kv_metrics(self):
if not hasattr(self.engine_client, "set_metrics_publisher"):
logging.debug("VLLM version does not support KV metrics")
return
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
# Create the structured metrics objects
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=1024,
num_requests_waiting=0,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request):
# logging.debug(f"Received request: {request}")
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
gen = self.engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
def _check_and_set_env_value(key, expected, allow_override=False):
if not allow_override and key in os.environ and os.environ[key] != expected:
raise ValueError(
f"{key} is set and doesn't equal expected {expected}. Please unset variable before launch."
)
os.environ.setdefault(key, expected)
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model": config.model_path,
"task": "generate",
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
assert config.kv_block_size > 0, "Must use non-negative integer for KV Block Size"
arg_map["block_size"] = config.kv_block_size
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
_check_and_set_env_value("VLLM_WORKER_ID", str(endpoint.lease_id()))
_check_and_set_env_value(
"VLLM_KV_CAPI_PATH", "libdynamo_llm_capi.so", allow_override=True
)
_check_and_set_env_value("VLLM_KV_NAMESPACE", config.namespace)
_check_and_set_env_value("VLLM_KV_COMPONENT", config.component)
_check_and_set_env_value(
"VLLM_NO_USAGE_STATS", "1", allow_override=True
) # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
default_sampling_params = model_config.get_diff_sampling_param()
engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__()
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
context_length=arg_map.get(
"max_model_len", None
), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"],
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="vLLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the vLLM AsyncLLMEngine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -80,10 +80,7 @@ pub async fn run(
(Box::pin(fut), Some(model.card().clone()))
}
EngineConfig::Dynamic(_) => {
// We can only get here for in=dyn out=vllm|sglang`, because vllm and sglang are a
// subprocess that we talk to like a remote endpoint.
// That means the vllm/sglang subprocess is doing all the work, we are idle.
(never_ready(), None)
unreachable!("An endpoint input will never have a Dynamic engine");
}
};
......@@ -107,7 +104,3 @@ pub async fn run(
Ok(())
}
fn never_ready() -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
Box::pin(std::future::pending::<anyhow::Result<()>>())
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment