Unverified Commit 93acc631 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(cli.py): dynamo-run-in-python handle sglang, vllm and trtllm (#1832)

parent e756f390
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import argparse import argparse
import asyncio import asyncio
import signal
import sys import sys
from pathlib import Path from pathlib import Path
...@@ -15,6 +16,9 @@ import uvloop ...@@ -15,6 +16,9 @@ import uvloop
from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
subprocess_ref = None # Global process reference for cleanup
subprocess_task = None # Global async task reference for cleanup
def parse_args(): def parse_args():
in_mode = "text" in_mode = "text"
...@@ -111,8 +115,39 @@ def parse_args(): ...@@ -111,8 +115,39 @@ def parse_args():
return parsed_args return parsed_args
async def cleanup_subprocess_async():
"""Clean up the sglang/vllm/trtllm subprocess if it exists."""
global subprocess_ref
if subprocess_ref and subprocess_ref.returncode is None:
subprocess_ref.terminate()
try:
await asyncio.wait_for(subprocess_ref.wait(), timeout=2)
except asyncio.TimeoutError:
subprocess_ref.kill()
await subprocess_ref.wait()
# Only cleanup once
subprocess_ref = None
def signal_handler():
"""Handle signals in async context by cleaning up subprocess and exiting."""
asyncio.create_task(cleanup_subprocess_async())
sys.exit(0)
async def run(): async def run():
global subprocess_ref
global subprocess_task
# Register signal handlers
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, signal_handler) # Ctrl-C
loop.add_signal_handler(signal.SIGTERM, signal_handler) # kill
# If we find cases where subprocess does not stop we may need this. Seem OK so far.
# atexit.register(cleanup_subprocess)
runtime = DistributedRuntime(loop, False) runtime = DistributedRuntime(loop, False)
args = parse_args() args = parse_args()
...@@ -124,13 +159,63 @@ async def run(): ...@@ -124,13 +159,63 @@ async def run():
"dyn": EngineType.Dynamic, "dyn": EngineType.Dynamic,
} }
out_mode = args["out_mode"] out_mode = args["out_mode"]
# Handle subprocess execution for sglang and vllm
if out_mode in ["sglang", "vllm", "trtllm"]:
# Determine which script to run
script_name = f"{out_mode}_inc.py"
script_path = Path(__file__).parent / script_name
if not script_path.exists():
print(f"Error: Script '{script_path}' not found")
sys.exit(1)
# Build command with all relevant arguments
cmd = [sys.executable, str(script_path)]
# Add arguments if they exist
if args["model_path"]:
cmd.extend(["--model-path", args["model_path"]])
flags = args["flags"]
if flags.model_name:
cmd.extend(["--model-name", flags.model_name])
if flags.context_length:
cmd.extend(["--context-length", str(flags.context_length)])
if flags.kv_cache_block_size:
cmd.extend(["--kv-cache-block-size", str(flags.kv_cache_block_size)])
# Start subprocess in background and stream output
print(f"Starting {out_mode} subprocess: {' '.join(cmd)}")
async def stream_subprocess_output():
global subprocess_ref
subprocess_ref = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
try:
if subprocess_ref.stdout is not None:
async for line in subprocess_ref.stdout:
print(f"Engine: {line.decode().rstrip()}")
await subprocess_ref.wait()
except asyncio.CancelledError:
# Task was cancelled, terminate the subprocess
await cleanup_subprocess_async()
raise
task = asyncio.create_task(stream_subprocess_output())
# Store the task reference for potential cleanup
subprocess_task = task
# Set out_mode to "dyn" because we talk to the subprocess over NATS
out_mode = "dyn"
engine_type = engine_type_map.get(out_mode) engine_type = engine_type_map.get(out_mode)
if engine_type is None: if engine_type is None:
print(f"Unsupported output type: {out_mode}") print(f"Unsupported output type: {out_mode}")
sys.exit(1) sys.exit(1)
# TODO: The "vllm", "sglang" and "trtllm" cases should call Python directly
entrypoint_kwargs = {"model_path": args["model_path"]} entrypoint_kwargs = {"model_path": args["model_path"]}
flags = args["flags"] flags = args["flags"]
...@@ -149,7 +234,20 @@ async def run(): ...@@ -149,7 +234,20 @@ async def run():
e = EntrypointArgs(engine_type, **entrypoint_kwargs) e = EntrypointArgs(engine_type, **entrypoint_kwargs)
engine = await make_engine(runtime, e) engine = await make_engine(runtime, e)
try:
await run_input(runtime, args["in_mode"], engine) await run_input(runtime, args["in_mode"], engine)
finally:
# Clean up subprocess when main execution finishes
await cleanup_subprocess_async()
# Cancel the subprocess task if it exists
if subprocess_task:
subprocess_task.cancel()
try:
await subprocess_task
except asyncio.CancelledError:
pass
if __name__ == "__main__": if __name__ == "__main__":
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
import argparse
import asyncio
import json
import logging
import sys
from typing import Optional
import sglang
import uvloop
from sglang.srt.entrypoints.engine import EmbeddingReqInput
from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
base_gpu_id: int
tensor_parallel_size: int
kv_block_size: int
context_length: int
nnodes: int
node_rank: int
dist_init_addr: str
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine):
self.engine_client = engine
async def generate(self, request):
sampling_params = {}
if request["sampling_options"]["temperature"] is not None:
sampling_params["temperature"] = request["sampling_options"]["temperature"]
# sglang defaults this to 128
sampling_params["max_new_tokens"] = request["stop_conditions"]["max_tokens"]
# Check if this is a batch request
is_batch = "batch_token_ids" in request and request["batch_token_ids"]
if is_batch:
# Track tokens separately for each batch item
num_output_tokens_so_far = {}
gen = await self.engine_client.async_generate(
input_ids=request["batch_token_ids"],
sampling_params=sampling_params,
stream=True,
)
else:
num_output_tokens_so_far = 0
gen = await self.engine_client.async_generate(
input_ids=request["token_ids"],
sampling_params=sampling_params,
stream=True,
)
async for res in gen:
# res is a dict
finish_reason = res["meta_info"]["finish_reason"]
if is_batch:
# Handle batch response - get index from SGLang response
index = res.get("index", 0)
if index not in num_output_tokens_so_far:
num_output_tokens_so_far[index] = 0
if finish_reason:
logging.warning(f"finish_reason: {finish_reason}")
# Final response for this batch item
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
"index": index,
}
else:
# Streaming response for this batch item
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far[index] :]
out = {
"token_ids": new_tokens,
"index": index,
}
num_output_tokens_so_far[index] = next_total_toks
else:
if finish_reason:
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
}
else:
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far:]
out = {
"token_ids": new_tokens,
}
num_output_tokens_so_far = next_total_toks
yield out
async def encode(self, request):
obj = EmbeddingReqInput(input_ids=request["token_ids"])
generator = self.engine_client.tokenizer_manager.generate_request(obj, None)
engine_results = await anext(generator)
tokens = 0
embeddings = []
for result in engine_results:
embeddings.append(result["embedding"])
tokens += result["meta_info"]["prompt_tokens"]
out = {
"embeddings": embeddings,
"prompt_tokens": tokens,
"total_tokens": tokens,
}
yield out
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model_path": config.model_path,
"skip_tokenizer_init": True,
"tp_size": config.tensor_parallel_size,
"base_gpu_id": config.base_gpu_id,
}
if config.kv_block_size:
arg_map["page_size"] = config.kv_block_size
if config.context_length:
arg_map["context_length"] = config.context_length
if config.dist_init_addr != "":
arg_map["trust_remote_code"] = True
arg_map["nnodes"] = config.nnodes
arg_map["dist_init_addr"] = config.dist_init_addr
# In practice this is always 0 because Dynamo only manages the leader
arg_map["node_rank"] = config.node_rank
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# TODO fetch default SamplingParams from generation_config.json
engine_args = ServerArgs(**arg_map)
engine_client = sglang.Engine(server_args=engine_args)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(model_type, endpoint, config.model_path, config.model_name)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
handler = RequestHandler(engine_client)
if engine_args.is_embedding:
await endpoint.serve_endpoint(handler.encode)
else:
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="SGLang server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=0,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--nnodes", type=int, default=1, help="The number of machines SGLang will use"
)
parser.add_argument(
"--node-rank",
type=int,
default=0,
help="Unique number for each node. 0 for the leader.",
)
parser.add_argument(
"--dist-init-addr",
type=str,
default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the SGLang Engine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.base_gpu_id = args.base_gpu_id
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# TODO:
# - Support disaggregated serving
# - Update examples to use this engine.
#
# `dynamo-run out=trtllm` runs this script
# Can be used standalone: `python3 trtllm_inc.py` - lots of optional cmd line params
#
# Disaggregated serving:
# - Ingress: dynamo run in=http out=dyn
# - Decode Worker: python3 trtllm_inc.py --task=decode --extra-engine-args=trtllm_config/sample.yaml
# - Prefill Worker: python3 trtllm_inc.py --task=prefill --extra-engine-args=trtllm_config/sample.yaml
import argparse
import asyncio
import base64
import copy
import logging
import sys
import warnings
from dataclasses import asdict, dataclass
from typing import Optional
import uvloop
# Import TRTLLM and related modules
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import DisaggregatedParams
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from dynamo.llm import (
ModelType,
get_tensorrtllm_engine,
get_tensorrtllm_publisher,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default endpoint for the remote prefill service.
DEFAULT_PREFILL_ENDPOINT = "dyn://dynamo.prefill.generate"
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging()
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
raise ValueError(
f"Invalid endpoint format: '{endpoint}'. "
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
return tuple(endpoint_parts)
class DisaggregatedParamsCodec:
"""
Codec for encoding and decoding disaggregated params for network transfer.
"""
@staticmethod
def decode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
opaque_state = (
base64.b64decode(disaggregated_params.opaque_state)
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
@staticmethod
def encode(
disaggregated_params: DisaggregatedParams,
) -> DisaggregatedParams:
if disaggregated_params is None:
return None
encoded_opaque_state = (
base64.b64encode(disaggregated_params.opaque_state).decode("utf-8")
if disaggregated_params.opaque_state is not None
else None
)
return DisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=encoded_opaque_state,
draft_tokens=disaggregated_params.draft_tokens,
)
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str] = None
tensor_parallel_size: int
kv_block_size: int
extra_engine_args: str
publish_events_and_metrics: bool
disaggregation_mode: str
remote_prefill_endpoint: str
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"remote_prefill_endpoint={self.remote_prefill_endpoint})"
)
@dataclass
class RequestHandlerConfig:
"""
Configuration for the request handler
"""
component: object
engine: object
default_sampling_params: object
publisher: object
disaggregation_mode: str
remote_prefill_client: object
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, config: RequestHandlerConfig):
self.engine = config.engine
self.component = config.component
self.default_sampling_params = config.default_sampling_params
self.publisher = config.publisher
self.disaggregation_mode = config.disaggregation_mode
self.remote_prefill_client = config.remote_prefill_client
self.first_generation = True
async def remote_prefill(self, request):
"""
Send a prefill request to the remote prefill worker.
Args:
request: The original request to be sent for prefill
Returns:
The response from the remote prefill worker
Raises:
ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request = copy.deepcopy(request)
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request["stop_conditions"]["max_tokens"] = 1
# Set the disaggregated params to context_only for remote prefill
prefill_request["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(
DisaggregatedParams(request_type="context_only")
)
)
if self.remote_prefill_client is None:
raise ValueError("Prefill client not initialized")
try:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses = [
remote_prefill_response
async for remote_prefill_response in await self.remote_prefill_client.round_robin(
prefill_request
)
]
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
if len(remote_prefill_responses) > 1:
raise ValueError(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
if len(remote_prefill_responses) == 0:
raise ValueError("No response received from remote prefill worker")
remote_prefill_response = remote_prefill_responses[0]
return remote_prefill_response
async def generate(self, request):
# Check if there is an error in the publisher error queue
publishers_error = (
self.publisher.check_error_queue() if self.publisher else None
)
if publishers_error:
raise publishers_error
inputs = request["token_ids"]
# Decode the disaggregated params from the request
if "disaggregated_params" in request:
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**request["disaggregated_params"])
)
else:
disaggregated_params = None
num_output_tokens_so_far = 0
if self.disaggregation_mode == "decode":
# Run prefill/context phase remotely if disaggregation mode is decode.
try:
prefill_result = await self.remote_prefill(request)
except Exception as e:
raise ValueError(f"Error in remote prefill: {e}")
remote_prefill_response = prefill_result.data()
if (
remote_prefill_response["finish_reason"] == "stop"
or remote_prefill_response["finish_reason"] == "error"
):
yield remote_prefill_response
return
num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
# Decode the disaggregated params from the remote prefill response
disaggregated_params = DisaggregatedParamsCodec.decode(
DisaggregatedParams(**remote_prefill_response["disaggregated_params"])
)
# Send the first token response to the client
first_token_response = remote_prefill_response
first_token_response.pop("disaggregated_params")
yield first_token_response
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params.request_type = "generation_only"
sampling_params = self.default_sampling_params
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
# TODO: Disable streaming for context only requests when adding disagg support
async for res in self.engine.llm.generate_async(
inputs=inputs,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=(self.disaggregation_mode != "prefill"),
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False
if res.finished and self.disaggregation_mode != "prefill":
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == "prefill":
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
logging.info(f"Initializing the worker with config: {config}")
remote_prefill_client = None
if config.disaggregation_mode == "decode":
logging.info(
f"Initializing remote prefill client for endpoint: {config.remote_prefill_endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.remote_prefill_endpoint
)
remote_prefill_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
# Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path)
arg_map = {
"model": model_path,
"tensor_parallel_size": config.tensor_parallel_size,
"backend": "pytorch",
"skip_tokenizer_init": True,
}
if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
if config.publish_events_and_metrics:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config = None
if "kv_cache_config" not in arg_map:
kv_cache_config = {}
kv_cache_config["event_buffer_max_size"] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else:
kv_cache_config = arg_map["kv_cache_config"]
if "event_buffer_max_size" not in kv_cache_config:
kv_cache_config[
"event_buffer_max_size"
] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
arg_map["kv_cache_config"] = kv_cache_config
# Only pytorch backend is supported for now to publish events and metrics.
if "backend" not in arg_map:
arg_map["backend"] = "pytorch"
elif arg_map["backend"] != "pytorch":
logging.error(
"Only pytorch backend is supported for now to publish events and metrics."
)
sys.exit(1)
logging.info(f"TRTLLM engine args: {arg_map}")
engine_args = arg_map
# Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
async with get_tensorrtllm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
if config.disaggregation_mode != "prefill":
# Register the model with the endpoint if disaggregation mode is not prefill.
# Prefill worker will get the request directly from the Decode worker and not
# through the ingress.
# FIXME: Enable publishing events and metrics for disaggregated prefill.
# Currently prefill workers are chosen in round-robin fashion.
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
)
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
component=component,
engine=engine,
default_sampling_params=default_sampling_params,
publisher=None,
disaggregation_mode=config.disaggregation_mode,
remote_prefill_client=remote_prefill_client,
)
if (
config.publish_events_and_metrics
and config.disaggregation_mode != "prefill"
):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
async with get_tensorrtllm_publisher(
component,
engine,
kv_listener,
int(endpoint.lease_id()),
config.kv_block_size,
) as publisher:
handler_config.publisher = publisher
handler = RequestHandler(handler_config)
await endpoint.serve_endpoint(handler.generate)
else:
handler = RequestHandler(handler_config)
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="TensorRT-LLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
)
parser.add_argument(
"--publish-events-and-metrics",
action="store_true",
help="Publish events and metrics to the dynamo components. Note: This is not supported when running in prefill disaggregation mode.",
)
parser.add_argument(
"--task",
type=str,
action="append",
choices=["prefill", "decode", "prefill_and_decode"],
default=[],
help="Specifies the task for the engine. Can be specified multiple time for different tasks. Will raise an error if conflicting tasks are specified.",
)
parser.add_argument(
"--remote-prefill-endpoint",
type=str,
default=DEFAULT_PREFILL_ENDPOINT,
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send prefill requests to when running in decode disaggregation mode. Default: {DEFAULT_PREFILL_ENDPOINT}",
)
args = parser.parse_args()
# Validate arguments
if args.context_length is not None:
warnings.warn(
"--context-length is accepted for compatibility but will be ignored for TensorRT-LLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
UserWarning,
)
endpoint = args.endpoint
# disaggregation mode
disaggregation_mode = None
for choice in ["prefill", "decode", "prefill_and_decode"]:
if choice in args.task:
if disaggregation_mode is not None:
raise ValueError(
f"Conflicting tasks specified: {args.task}. Please specify only one task."
)
disaggregation_mode = choice
if disaggregation_mode is None:
disaggregation_mode = "prefill_and_decode"
if disaggregation_mode == "prefill":
if args.remote_prefill_endpoint != DEFAULT_PREFILL_ENDPOINT:
logging.error(
"--remote-prefill-endpoint is not supported when running in prefill disaggregation mode."
)
sys.exit(1)
else:
endpoint = DEFAULT_PREFILL_ENDPOINT
if args.publish_events_and_metrics:
warnings.warn(
"--publish-events-and-metrics is not supported when running in prefill disaggregation mode.",
UserWarning,
)
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode
config.remote_prefill_endpoint = args.remote_prefill_endpoint
return config
if __name__ == "__main__":
uvloop.install()
try:
asyncio.run(worker())
except KeyboardInterrupt:
logging.info("Received SIGINT, shutting down...")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed - and patched if using kv routing.
# - `libdynamo_llm_capi.so` is in system lib path or it's containing folder is in LD_LIBRARY_PATH
# It builds in target/debug/ by default.
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Optional
import uvloop
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs import TokensPrompt
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
context_length: int
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.metrics_publisher = WorkerMetricsPublisher()
def setup_kv_metrics(self):
if not hasattr(self.engine_client, "set_metrics_publisher"):
logging.debug("VLLM version does not support KV metrics")
return
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
# Create the structured metrics objects
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=1024,
num_requests_waiting=0,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request):
# logging.debug(f"Received request: {request}")
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
gen = self.engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
def _check_and_set_env_value(key, expected, allow_override=False):
if not allow_override and key in os.environ and os.environ[key] != expected:
raise ValueError(
f"{key} is set and doesn't equal expected {expected}. Please unset variable before launch."
)
os.environ.setdefault(key, expected)
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model": config.model_path,
"task": "generate",
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
assert config.kv_block_size > 0, "Must use non-negative integer for KV Block Size"
arg_map["block_size"] = config.kv_block_size
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
_check_and_set_env_value("VLLM_WORKER_ID", str(endpoint.lease_id()))
_check_and_set_env_value(
"VLLM_KV_CAPI_PATH", "libdynamo_llm_capi.so", allow_override=True
)
_check_and_set_env_value("VLLM_KV_NAMESPACE", config.namespace)
_check_and_set_env_value("VLLM_KV_COMPONENT", config.component)
_check_and_set_env_value(
"VLLM_NO_USAGE_STATS", "1", allow_override=True
) # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
default_sampling_params = model_config.get_diff_sampling_param()
engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__()
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
context_length=arg_map.get(
"max_model_len", None
), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"],
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="vLLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the vLLM AsyncLLMEngine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment