Unverified Commit 93acc631 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(cli.py): dynamo-run-in-python handle sglang, vllm and trtllm (#1832)

parent e756f390
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import argparse import argparse
import asyncio import asyncio
import signal
import sys import sys
from pathlib import Path from pathlib import Path
...@@ -15,6 +16,9 @@ import uvloop ...@@ -15,6 +16,9 @@ import uvloop
from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
subprocess_ref = None # Global process reference for cleanup
subprocess_task = None # Global async task reference for cleanup
def parse_args(): def parse_args():
in_mode = "text" in_mode = "text"
...@@ -111,8 +115,39 @@ def parse_args(): ...@@ -111,8 +115,39 @@ def parse_args():
return parsed_args return parsed_args
async def cleanup_subprocess_async():
"""Clean up the sglang/vllm/trtllm subprocess if it exists."""
global subprocess_ref
if subprocess_ref and subprocess_ref.returncode is None:
subprocess_ref.terminate()
try:
await asyncio.wait_for(subprocess_ref.wait(), timeout=2)
except asyncio.TimeoutError:
subprocess_ref.kill()
await subprocess_ref.wait()
# Only cleanup once
subprocess_ref = None
def signal_handler():
"""Handle signals in async context by cleaning up subprocess and exiting."""
asyncio.create_task(cleanup_subprocess_async())
sys.exit(0)
async def run(): async def run():
global subprocess_ref
global subprocess_task
# Register signal handlers
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, signal_handler) # Ctrl-C
loop.add_signal_handler(signal.SIGTERM, signal_handler) # kill
# If we find cases where subprocess does not stop we may need this. Seem OK so far.
# atexit.register(cleanup_subprocess)
runtime = DistributedRuntime(loop, False) runtime = DistributedRuntime(loop, False)
args = parse_args() args = parse_args()
...@@ -124,13 +159,63 @@ async def run(): ...@@ -124,13 +159,63 @@ async def run():
"dyn": EngineType.Dynamic, "dyn": EngineType.Dynamic,
} }
out_mode = args["out_mode"] out_mode = args["out_mode"]
# Handle subprocess execution for sglang and vllm
if out_mode in ["sglang", "vllm", "trtllm"]:
# Determine which script to run
script_name = f"{out_mode}_inc.py"
script_path = Path(__file__).parent / script_name
if not script_path.exists():
print(f"Error: Script '{script_path}' not found")
sys.exit(1)
# Build command with all relevant arguments
cmd = [sys.executable, str(script_path)]
# Add arguments if they exist
if args["model_path"]:
cmd.extend(["--model-path", args["model_path"]])
flags = args["flags"]
if flags.model_name:
cmd.extend(["--model-name", flags.model_name])
if flags.context_length:
cmd.extend(["--context-length", str(flags.context_length)])
if flags.kv_cache_block_size:
cmd.extend(["--kv-cache-block-size", str(flags.kv_cache_block_size)])
# Start subprocess in background and stream output
print(f"Starting {out_mode} subprocess: {' '.join(cmd)}")
async def stream_subprocess_output():
global subprocess_ref
subprocess_ref = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
try:
if subprocess_ref.stdout is not None:
async for line in subprocess_ref.stdout:
print(f"Engine: {line.decode().rstrip()}")
await subprocess_ref.wait()
except asyncio.CancelledError:
# Task was cancelled, terminate the subprocess
await cleanup_subprocess_async()
raise
task = asyncio.create_task(stream_subprocess_output())
# Store the task reference for potential cleanup
subprocess_task = task
# Set out_mode to "dyn" because we talk to the subprocess over NATS
out_mode = "dyn"
engine_type = engine_type_map.get(out_mode) engine_type = engine_type_map.get(out_mode)
if engine_type is None: if engine_type is None:
print(f"Unsupported output type: {out_mode}") print(f"Unsupported output type: {out_mode}")
sys.exit(1) sys.exit(1)
# TODO: The "vllm", "sglang" and "trtllm" cases should call Python directly
entrypoint_kwargs = {"model_path": args["model_path"]} entrypoint_kwargs = {"model_path": args["model_path"]}
flags = args["flags"] flags = args["flags"]
...@@ -149,7 +234,20 @@ async def run(): ...@@ -149,7 +234,20 @@ async def run():
e = EntrypointArgs(engine_type, **entrypoint_kwargs) e = EntrypointArgs(engine_type, **entrypoint_kwargs)
engine = await make_engine(runtime, e) engine = await make_engine(runtime, e)
try:
await run_input(runtime, args["in_mode"], engine) await run_input(runtime, args["in_mode"], engine)
finally:
# Clean up subprocess when main execution finishes
await cleanup_subprocess_async()
# Cancel the subprocess task if it exists
if subprocess_task:
subprocess_task.cancel()
try:
await subprocess_task
except asyncio.CancelledError:
pass
if __name__ == "__main__": if __name__ == "__main__":
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
import argparse
import asyncio
import json
import logging
import sys
from typing import Optional
import sglang
import uvloop
from sglang.srt.entrypoints.engine import EmbeddingReqInput
from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
base_gpu_id: int
tensor_parallel_size: int
kv_block_size: int
context_length: int
nnodes: int
node_rank: int
dist_init_addr: str
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine):
self.engine_client = engine
async def generate(self, request):
sampling_params = {}
if request["sampling_options"]["temperature"] is not None:
sampling_params["temperature"] = request["sampling_options"]["temperature"]
# sglang defaults this to 128
sampling_params["max_new_tokens"] = request["stop_conditions"]["max_tokens"]
# Check if this is a batch request
is_batch = "batch_token_ids" in request and request["batch_token_ids"]
if is_batch:
# Track tokens separately for each batch item
num_output_tokens_so_far = {}
gen = await self.engine_client.async_generate(
input_ids=request["batch_token_ids"],
sampling_params=sampling_params,
stream=True,
)
else:
num_output_tokens_so_far = 0
gen = await self.engine_client.async_generate(
input_ids=request["token_ids"],
sampling_params=sampling_params,
stream=True,
)
async for res in gen:
# res is a dict
finish_reason = res["meta_info"]["finish_reason"]
if is_batch:
# Handle batch response - get index from SGLang response
index = res.get("index", 0)
if index not in num_output_tokens_so_far:
num_output_tokens_so_far[index] = 0
if finish_reason:
logging.warning(f"finish_reason: {finish_reason}")
# Final response for this batch item
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
"index": index,
}
else:
# Streaming response for this batch item
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far[index] :]
out = {
"token_ids": new_tokens,
"index": index,
}
num_output_tokens_so_far[index] = next_total_toks
else:
if finish_reason:
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
}
else:
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far:]
out = {
"token_ids": new_tokens,
}
num_output_tokens_so_far = next_total_toks
yield out
async def encode(self, request):
obj = EmbeddingReqInput(input_ids=request["token_ids"])
generator = self.engine_client.tokenizer_manager.generate_request(obj, None)
engine_results = await anext(generator)
tokens = 0
embeddings = []
for result in engine_results:
embeddings.append(result["embedding"])
tokens += result["meta_info"]["prompt_tokens"]
out = {
"embeddings": embeddings,
"prompt_tokens": tokens,
"total_tokens": tokens,
}
yield out
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model_path": config.model_path,
"skip_tokenizer_init": True,
"tp_size": config.tensor_parallel_size,
"base_gpu_id": config.base_gpu_id,
}
if config.kv_block_size:
arg_map["page_size"] = config.kv_block_size
if config.context_length:
arg_map["context_length"] = config.context_length
if config.dist_init_addr != "":
arg_map["trust_remote_code"] = True
arg_map["nnodes"] = config.nnodes
arg_map["dist_init_addr"] = config.dist_init_addr
# In practice this is always 0 because Dynamo only manages the leader
arg_map["node_rank"] = config.node_rank
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# TODO fetch default SamplingParams from generation_config.json
engine_args = ServerArgs(**arg_map)
engine_client = sglang.Engine(server_args=engine_args)
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(model_type, endpoint, config.model_path, config.model_name)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
handler = RequestHandler(engine_client)
if engine_args.is_embedding:
await endpoint.serve_endpoint(handler.encode)
else:
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="SGLang server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=0,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--nnodes", type=int, default=1, help="The number of machines SGLang will use"
)
parser.add_argument(
"--node-rank",
type=int,
default=0,
help="Unique number for each node. 0 for the leader.",
)
parser.add_argument(
"--dist-init-addr",
type=str,
default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the SGLang Engine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.base_gpu_id = args.base_gpu_id
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# `dynamo-run out=vllm` runs this script
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
# Setup checklist:
# - We are in a virtualenv with vllm installed - and patched if using kv routing.
# - `libdynamo_llm_capi.so` is in system lib path or it's containing folder is in LD_LIBRARY_PATH
# It builds in target/debug/ by default.
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Optional
import uvloop
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs import TokensPrompt
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
ModelType,
WorkerMetricsPublisher,
WorkerStats,
register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# Only used if you run it manually from the command line
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
configure_dynamo_logging()
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model_path: str
model_name: Optional[str]
tensor_parallel_size: int
kv_block_size: int
context_length: int
extra_engine_args: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, component, engine, default_sampling_params):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.metrics_publisher = WorkerMetricsPublisher()
def setup_kv_metrics(self):
if not hasattr(self.engine_client, "set_metrics_publisher"):
logging.debug("VLLM version does not support KV metrics")
return
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
# Create the structured metrics objects
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=1024,
num_requests_waiting=0,
data_parallel_rank=None,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=1024,
gpu_cache_usage_perc=0.0,
gpu_prefix_cache_hit_rate=0.0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
)
# Publish the metrics as a single object
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logging.debug("metrics publisher endpoint created")
)
async def create_metrics_publisher_endpoint(self):
logging.debug("Creating metrics publisher endpoint")
await self.metrics_publisher.create_endpoint(self.component)
async def generate(self, request):
# logging.debug(f"Received request: {request}")
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request["stop_conditions"]["max_tokens"]
if max_tokens:
sampling_params.max_tokens = max_tokens
num_output_tokens_so_far = 0
gen = self.engine_client.generate(prompt, sampling_params, request_id)
async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
def _check_and_set_env_value(key, expected, allow_override=False):
if not allow_override and key in os.environ and os.environ[key] != expected:
raise ValueError(
f"{key} is set and doesn't equal expected {expected}. Please unset variable before launch."
)
os.environ.setdefault(key, expected)
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
arg_map = {
"model": config.model_path,
"task": "generate",
"tensor_parallel_size": config.tensor_parallel_size,
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
assert config.kv_block_size > 0, "Must use non-negative integer for KV Block Size"
arg_map["block_size"] = config.kv_block_size
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
if config.extra_engine_args != "":
json_map = {}
# extra_engine_args is a filename
try:
with open(config.extra_engine_args) as f:
json_map = json.load(f)
except FileNotFoundError:
logging.error(f"File {config.extra_engine_args} not found.")
except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
logging.debug(f"Adding extra engine arguments: {json_map}")
arg_map = {**arg_map, **json_map} # json_map gets precedence
# Patch won't start KVCacheEventManager unless these four are set
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
_check_and_set_env_value("VLLM_WORKER_ID", str(endpoint.lease_id()))
_check_and_set_env_value(
"VLLM_KV_CAPI_PATH", "libdynamo_llm_capi.so", allow_override=True
)
_check_and_set_env_value("VLLM_KV_NAMESPACE", config.namespace)
_check_and_set_env_value("VLLM_KV_COMPONENT", config.component)
_check_and_set_env_value(
"VLLM_NO_USAGE_STATS", "1", allow_override=True
) # Avoid internal HTTP requests
engine_args = AsyncEngineArgs(**arg_map)
model_config = engine_args.create_model_config()
# Load default sampling params from `generation_config.json`
default_sampling_params = model_config.get_diff_sampling_param()
engine_context = build_async_engine_client_from_engine_args(engine_args)
engine_client = await engine_context.__aenter__()
await register_llm(
ModelType.Backend,
endpoint,
config.model_path,
config.model_name,
context_length=arg_map.get(
"max_model_len", None
), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"],
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(handler.generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="vLLM server integrated with Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model-path",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="Name to serve the model under. Defaults to deriving it from model path.",
)
parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
)
parser.add_argument(
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
)
parser.add_argument(
"--context-length",
type=int,
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
default="",
help="Path to a JSON file containing additional keyword arguments to pass to the vLLM AsyncLLMEngine.",
)
args = parser.parse_args()
config = Config()
config.model_path = args.model_path
if args.model_name:
config.model_name = args.model_name
else:
# This becomes an `Option` on the Rust side
config.model_name = None
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logging.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.extra_engine_args = args.extra_engine_args
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment