Unverified commit da40db40, authored by Tanmay Verma and committed by GitHub

feat: add unified backend architecture with DynamoBackend (#8003)


Signed-off-by: Tanmay Verma <tanmayv@nvidia.com>
parent f3b181a9
@@ -10,7 +10,7 @@
 **/__pycache__
 **/*.pyc
 **/*onnx*
-# Engine must be allowed because code contains dynamo_engine.py
+# Engine must be allowed because code contains llm_engine.py
 **/*tensorrtllm_engines*
 **/*tensorrtllm_models*
 **/*tensorrtllm_checkpoints*
......
@@ -263,7 +263,7 @@ jobs:
 run_cpu_only_tests: true
 cpu_only_test_markers: pre_merge and vllm and gpu_0
 gpu_test_markers: pre_merge and vllm and gpu_1
-gpu_test_timeout_minutes: 35
+gpu_test_timeout_minutes: 45
 secrets: inherit
 vllm-multi-gpu-test:
......
# Backend Module
Two-class abstraction: `Worker` (runtime integration) and
`LLMEngine` (ABC for engine-specific logic). See `README.md` for full docs.
## Engine Lifecycle
```
from_args(argv)  ->  start()          ->  generate() / abort()  ->  cleanup()
      |                 |                        |                     |
 parse args,       start engine,          serve requests          shutdown,
 return config     return metadata        (concurrent)            release resources
```
1. `from_args(argv)` -- classmethod factory. Parses CLI args, returns
`(engine, WorkerConfig)`. Engine is NOT started yet.
2. `start()` -- starts the engine, returns `EngineConfig`. After this returns
`generate()` MUST be ready to accept calls.
3. `generate(request, context)` -- streaming inference, called concurrently.
4. `abort(context)` -- cancel an in-flight request (optional, default no-op).
5. `cleanup()` -- called once on shutdown.
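The call order above can be sketched with a toy engine. This is a minimal illustration, not the module's code: `ToyEngine` and `drive` are hypothetical names, and plain dicts stand in for `WorkerConfig` and `EngineConfig` so the snippet runs without Dynamo installed.

```python
import asyncio


class ToyEngine:
    """Hypothetical stand-in for an LLMEngine subclass (not part of the module)."""

    def __init__(self, max_tokens, calls):
        self.max_tokens = max_tokens
        self.calls = calls  # records the order of lifecycle calls

    @classmethod
    async def from_args(cls, argv=None):
        calls = ["from_args"]
        max_tokens = int(argv[0]) if argv else 4
        # A plain dict stands in for WorkerConfig in this sketch.
        worker_config = {"namespace": "dynamo", "component": "toy"}
        return cls(max_tokens, calls), worker_config

    async def start(self):
        self.calls.append("start")
        # A plain dict stands in for EngineConfig.
        return {"model": "toy-model", "context_length": 2048}

    async def generate(self, request, context):
        self.calls.append("generate")
        prompt_len = len(request["token_ids"])
        for i in range(self.max_tokens):
            chunk = {"token_ids": [i + 1]}
            if i == self.max_tokens - 1:
                # The final chunk carries finish_reason and completion_usage.
                chunk["finish_reason"] = "length"
                chunk["completion_usage"] = {
                    "prompt_tokens": prompt_len,
                    "completion_tokens": self.max_tokens,
                    "total_tokens": prompt_len + self.max_tokens,
                }
            yield chunk

    async def cleanup(self):
        self.calls.append("cleanup")


async def drive(argv):
    # from_args() builds the engine but does not start it.
    engine, worker_config = await ToyEngine.from_args(argv)
    # After start() returns, generate() must be ready to accept calls.
    engine_config = await engine.start()
    try:
        chunks = [c async for c in engine.generate({"token_ids": [7, 8]}, None)]
    finally:
        # cleanup() runs once on shutdown, even if serving failed.
        await engine.cleanup()
    return engine.calls, chunks


calls, chunks = asyncio.run(drive(["3"]))
print(calls)  # ['from_args', 'start', 'generate', 'cleanup']
```

The `try`/`finally` mirrors the guarantee that `cleanup()` is called once on shutdown regardless of how serving ends.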
## Design Constraints
- **ZERO duplication across engine implementations.** This is the #1 priority.
The entire reason this module exists is to eliminate the code duplication
that grew across vllm, sglang, and trtllm. Before writing any logic inside
an `LLMEngine` subclass, check whether the same logic already exists in
another engine. If it does, extract it into `Worker` or a shared
utility and have all engines call the shared version.
When adding new features, always ask: "is this engine-specific or common?"
If two or more engines would need the same code, it is common.
- **Exactly two classes.** `Worker` owns runtime lifecycle.
`LLMEngine` owns inference. Do not add intermediate base classes or mixins.
- **`from_args()` returns `(engine, WorkerConfig)`.** The tuple return
makes the contract statically checkable -- a subclass that forgets to
build a `WorkerConfig` is a type error, not a runtime `AttributeError`.
- **`generate()` delegates to engine with cancellation monitoring.**
`Worker.generate()` runs a background task that watches
`context.async_killed_or_stopped()` and calls `engine.abort(context)` on
cancellation. It also checks `context.is_stopped()` after each yielded
chunk. Sampling params, prompt building, and output formatting stay inside
each engine -- they are deeply engine-specific.
- **`start()` returns `EngineConfig`.** `Worker` needs registration
metadata (`context_length`, `kv_cache_block_size`, `total_kv_blocks`) but
must not reach into engine internals. `start()` returns this metadata so the
boundary stays clean.
- **No hooks.** If behavior needs to be shared across engines, put it in
`Worker` or a shared utility, not in a hook system.
- **Parallel path.** The existing `main.py` / `worker_factory.py` / `init_llm.py`
entry points remain untouched. The `unified_main.py` files are a separate
path. Do not break or modify existing backends when changing this module.
## Request / Response Contract
`GenerateRequest` and `GenerateChunk` (`engine.py`) are `TypedDict`s that
type the `generate()` signature. `GenerateRequest` has `token_ids`
(required) plus optional `sampling_options`, `stop_conditions`, and
`output_options`. `GenerateChunk` has `token_ids` (required) plus
optional `finish_reason` and `completion_usage` (both required on the
final chunk). Engines may read/write additional keys — `TypedDict` does
not reject extras at runtime.
Build the `completion_usage` dict inline. Finish reason normalization
(e.g. `"abort"` → `"cancelled"`) is handled by the Rust layer.
## Adding a New Engine
1. Create `<backend>/llm_engine.py` subclassing `LLMEngine`
2. Implement `from_args()`, `start()`, `generate()`, `cleanup()` (required)
and `abort()` (optional)
3. `from_args()` must parse args and return `(engine, WorkerConfig)`
4. Create `<backend>/unified_main.py` calling `run(<YourEngine>)`
5. Use `sample_engine.py` as the reference implementation
## Error Handling
`Worker` wraps lifecycle and generate errors in
`DynamoException` subclasses (`dynamo.llm.exceptions`). The Rust bridge
(`engine.rs`) converts these into typed `DynamoError::Backend(...)` for
proper error chain observability. Engines can raise `DynamoException`
subclasses directly from `generate()` -- these pass through unchanged.
Non-`DynamoException` errors are wrapped as `Unknown`.
## Logging
Keep logging **standardized across all three engines** (vllm, sglang, trtllm).
When adding or changing a log message in one `llm_engine.py`, check
whether the same lifecycle event is logged in the other two and update them
to match. The goal is that operators see the same log shape regardless of
backend, making it easier to triage issues across mixed deployments.
Standardize on:
- `logger.info` for lifecycle milestones: engine init complete, serving
started, engine shutdown.
- `logger.debug` for per-request events: request abort, cancellation.
- `logger.warning` for recoverable problems: empty outputs, unexpected
finish reasons.
- `logger.error` only for unrecoverable failures.
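One way to keep the log shape aligned is to route each lifecycle event through the same level in every engine. The sketch below is only an illustration of the level convention; the logger name, helper names, and message texts are hypothetical and do not come from the module.

```python
import logging

# Hypothetical logger name; each real llm_engine.py uses its own module logger.
logger = logging.getLogger("toy_backend.llm_engine")


def log_engine_started(model):
    # Lifecycle milestone -> info
    logger.info("Engine init complete: model=%s", model)


def log_request_aborted(request_id):
    # Per-request event -> debug
    logger.debug("Request aborted: id=%s", request_id)


def log_empty_output(request_id):
    # Recoverable problem -> warning
    logger.warning("Engine returned empty output: id=%s", request_id)


def log_fatal(exc):
    # Unrecoverable failure -> error
    logger.error("Engine failed and cannot recover: %s", exc)
```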
## Key Files
| File | What it does |
|------|-------------|
| `engine.py` | `LLMEngine` ABC -- the only interface engines must implement |
| `worker.py` | `Worker` -- runtime lifecycle: create runtime, register model, serve endpoint, cleanup |
| `run.py` | Common entry point -- `run(engine_cls)` used by all `unified_main.py` files |
| `sample_engine.py` | Reference engine -- use as template and for testing |
# Dynamo Python Backend
> **Work in progress.** The unified backend currently supports minimal
> aggregated inference only. See [Feature Gaps](#feature-gaps) at the bottom
> for what remains to be implemented.
A two-class abstraction that separates **runtime integration** (common across
all backends) from **engine logic** (vLLM, SGLang, TensorRT-LLM, etc.).
## Architecture
```
LLMEngine (ABC)                  <-- engine boundary (engine.py)
  |  - from_args(argv) -> (LLMEngine, WorkerConfig)   (factory)
  |  - start() -> EngineConfig                        (start engine, return metadata)
  |  - generate(request, context)                     (streaming inference)
  |  - abort(context)                                 (cancel request, optional)
  |  - cleanup()                                      (shutdown)
  |
  +-- VllmLLMEngine              <-- vllm/llm_engine.py
  +-- SglangLLMEngine            <-- sglang/llm_engine.py
  +-- TrtllmLLMEngine            <-- trtllm/llm_engine.py
  +-- SampleLLMEngine            <-- sample_engine.py

Worker                           <-- runtime integration (worker.py)
  - receives WorkerConfig from from_args()
  - creates DistributedRuntime
  - sets up endpoints, signal handlers
  - calls engine.start(), registers model
  - serves generate endpoint with cancellation monitoring
  - calls engine.cleanup() on shutdown
```
## Quick Start
### Running the sample engine
```bash
python -m dynamo.common.backend.sample_main \
--model-name test-model \
--namespace dynamo \
--component sample \
--endpoint generate
```
This starts a backend that generates rotating token IDs. Point a frontend at
`dynamo.sample.generate` to test the full request flow without any ML
dependencies.
### Running a real engine
```bash
# vLLM
python -m dynamo.vllm.unified_main --model Qwen/Qwen3-0.6B ...
# SGLang
python -m dynamo.sglang.unified_main --model-path Qwen/Qwen3-0.6B ...
# TensorRT-LLM
python -m dynamo.trtllm.unified_main --model Qwen/Qwen3-0.6B ...
```
Each `unified_main.py` calls `run(MyLLMEngine)` from the common
`run.py` module.
## Implementing a New Engine
Subclass `LLMEngine` and implement the required methods:
```python
from dynamo.common.backend import LLMEngine, EngineConfig, WorkerConfig


class MyEngine(LLMEngine):
    @classmethod
    async def from_args(cls, argv=None):
        # Parse CLI args, construct engine and worker_config.
        engine = cls(...)
        worker_config = WorkerConfig(
            namespace="dynamo", component="my-backend", ...
        )
        return engine, worker_config

    async def start(self) -> EngineConfig:
        # Start the engine, return metadata for model registration.
        # After this returns, generate() MUST be ready to accept calls.
        return EngineConfig(
            model="my-model",
            context_length=4096,
            kv_cache_block_size=16,
        )

    async def generate(self, request, context):
        # Yield streaming response dicts. `my_engine` is a placeholder for
        # the backend's own client object.
        prompt_tokens = len(request["token_ids"])
        completion_tokens = 0
        async for result in my_engine.run(request):
            completion_tokens += len(result.token_ids)
            yield {"token_ids": result.token_ids}
        # The final chunk must carry finish_reason and completion_usage.
        yield {
            "token_ids": [],
            "finish_reason": "stop",
            "completion_usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    async def abort(self, context):
        # Cancel an in-flight request (optional, default is no-op).
        await my_engine.cancel(context.id())

    async def cleanup(self):
        # Shut down the engine.
        pass
```
Then create an entry point:
```python
# my_backend/unified_main.py
from dynamo.common.backend.run import run
from my_backend.llm_engine import MyEngine


def main():
    run(MyEngine)
```
See `sample_engine.py` for a complete, runnable reference implementation.
## Request / Response Types
`GenerateRequest` and `GenerateChunk` (defined in `engine.py`) are
`TypedDict`s that document the shared fields across all engines.
```python
class GenerateRequest(TypedDict, total=False):
    token_ids: Required[list[int]]
    sampling_options: dict[str, Any]
    stop_conditions: dict[str, Any]
    output_options: dict[str, Any]


class GenerateChunk(TypedDict, total=False):
    token_ids: Required[list[int]]
    finish_reason: str                # final chunk only
    completion_usage: dict[str, int]  # final chunk only
```
Engines may read additional backend-specific keys from the request dict
and write additional keys into response chunks — `TypedDict` does not
reject extra keys at runtime.
Build the `completion_usage` dict inline. Finish reason normalization
(e.g. `"abort"` → `"cancelled"`) is handled by the Rust layer.
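As a rough sketch of the contract, the generator below reads the optional `stop_conditions` group defensively and builds `completion_usage` inline on the final chunk. `echo_generate` and `collect` are hypothetical names used only for illustration; the "engine" simply echoes prompt tokens back.

```python
import asyncio


async def echo_generate(request):
    """Hypothetical generate() body: echoes prompt tokens, one per chunk."""
    token_ids = request["token_ids"]                      # always present
    stop_conditions = request.get("stop_conditions", {})  # optional group
    max_new = stop_conditions.get("max_tokens") or len(token_ids)
    emitted = token_ids[:max_new]
    for i, tok in enumerate(emitted):
        chunk = {"token_ids": [tok]}
        if i == len(emitted) - 1:
            # Final chunk: finish_reason plus completion_usage, built inline.
            chunk["finish_reason"] = "length"
            chunk["completion_usage"] = {
                "prompt_tokens": len(token_ids),
                "completion_tokens": len(emitted),
                "total_tokens": len(token_ids) + len(emitted),
            }
        yield chunk


async def collect(request):
    return [chunk async for chunk in echo_generate(request)]


chunks = asyncio.run(
    collect({"token_ids": [5, 6, 7], "stop_conditions": {"max_tokens": 2}})
)
print(chunks[-1]["completion_usage"])
# {'prompt_tokens': 3, 'completion_tokens': 2, 'total_tokens': 5}
```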
## Request Cancellation
`Worker.generate()` automatically monitors for client
disconnections and request cancellations via `context.async_killed_or_stopped()`.
When triggered, it:
1. Calls `engine.abort(context)` to release engine resources (KV cache,
scheduler slots, etc.)
2. Breaks out of the generation loop
3. Cleans up the monitoring task
Engine implementations should override `abort(context)` to perform
backend-specific cleanup:
| Engine | Abort method | ID used |
|--------|-------------|---------|
| vLLM | `engine_client.abort(request_id)` | `context.id()` |
| SGLang | `tokenizer_manager.abort_request(rid=...)` | `context.trace_id` |
| TRT-LLM | `generation_result.abort()` | Tracked per-request via `context.id()` |
| Sample | *(no-op, default)* | — |
Engines that don't support cancellation can skip overriding `abort()`;
the default implementation is a no-op. The generation loop will still
break on `context.is_stopped()`.
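The monitoring pattern can be tried in isolation with a fake context. This is a sketch under stated assumptions: `FakeContext`, `SlowEngine`, and `monitored_generate` are hypothetical stand-ins written for this example, and only the two context methods named above (`async_killed_or_stopped()` and `is_stopped()`) are modeled.

```python
import asyncio


class FakeContext:
    """Hypothetical stand-in for the runtime Context."""

    def __init__(self):
        self._stopped = asyncio.Event()

    def stop(self):
        self._stopped.set()

    def is_stopped(self):
        return self._stopped.is_set()

    async def async_killed_or_stopped(self):
        await self._stopped.wait()


class SlowEngine:
    """Hypothetical engine that streams slowly and records abort calls."""

    def __init__(self):
        self.aborted = False

    async def generate(self, request, context):
        for i in range(100):
            await asyncio.sleep(0.01)
            yield {"token_ids": [i]}

    async def abort(self, context):
        self.aborted = True


async def monitored_generate(engine, request, context):
    async def monitor_cancel():
        # Wakes up only when the request is killed or stopped.
        await context.async_killed_or_stopped()
        await engine.abort(context)

    cancel_task = asyncio.create_task(monitor_cancel())
    try:
        async for chunk in engine.generate(request, context):
            if context.is_stopped():
                break  # stop streaming even if abort() is a no-op
            yield chunk
    finally:
        # Clean up the monitor whether or not it fired.
        if not cancel_task.done():
            cancel_task.cancel()
        try:
            await cancel_task
        except asyncio.CancelledError:
            pass


async def demo():
    engine, context = SlowEngine(), FakeContext()
    received = []
    async for chunk in monitored_generate(engine, {"token_ids": [1]}, context):
        received.append(chunk)
        if len(received) == 3:
            context.stop()  # simulate a client disconnect
    return engine.aborted, len(received)


aborted, count = asyncio.run(demo())
```

After the simulated disconnect, the monitor task calls `abort()` while the engine is still awaiting its next token, and the loop exits on the next `is_stopped()` check without yielding further chunks.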
## Error Handling
`Worker` wraps errors in `DynamoException` subclasses from
`dynamo.llm.exceptions` so the Rust bridge can map them to typed
`DynamoError::Backend(...)` responses with proper error chains.
| Phase | Exception raised | When |
|-------|-----------------|------|
| Runtime creation | `CannotConnect` | etcd/NATS unreachable |
| Engine init | `EngineShutdown` | Engine fails to start (OOM, bad config, etc.) |
| Generate | `Unknown` | Untyped exception from engine `generate()` |
| Generate | *(pass-through)* | Engine raises a `DynamoException` subclass directly |
Engine implementations can raise `DynamoException` subclasses directly from
`generate()` for fine-grained error reporting — these propagate unchanged.
Any non-`DynamoException` errors are wrapped as `Unknown`.
Available exception types (from `dynamo.llm.exceptions`):
```python
from dynamo.llm.exceptions import (
DynamoException, # Base class
Unknown, # Uncategorized error
InvalidArgument, # Bad input (e.g., prompt too long)
CannotConnect, # Connection failed
Disconnected, # Connection lost
ConnectionTimeout, # Timeout
Cancelled, # Client cancelled
EngineShutdown, # Engine crashed or shutting down
StreamIncomplete, # Response stream cut short
)
```
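The pass-through versus wrap rule described in this section can be sketched as follows. The exception classes here are local stand-ins that only mimic the names in `dynamo.llm.exceptions` so the snippet runs without Dynamo; `guarded`, `typed_failure`, `untyped_failure`, and `outcome` are hypothetical helpers.

```python
import asyncio


class DynamoException(Exception):
    """Local stand-in for dynamo.llm.exceptions.DynamoException."""


class Unknown(DynamoException):
    """Local stand-in for the uncategorized error type."""


class InvalidArgument(DynamoException):
    """Local stand-in for the bad-input error type."""


async def guarded(stream):
    """Re-raise typed errors unchanged; wrap everything else as Unknown."""
    try:
        async for chunk in stream:
            yield chunk
    except DynamoException:
        raise
    except Exception as exc:
        raise Unknown(f"Engine generate failed: {exc}") from exc


async def typed_failure():
    raise InvalidArgument("prompt too long")
    yield  # unreachable; makes this an async generator


async def untyped_failure():
    raise ValueError("boom")
    yield  # unreachable; makes this an async generator


async def outcome(stream):
    try:
        async for _ in guarded(stream):
            pass
    except DynamoException as exc:
        # Report the surfaced type and the type of the chained cause.
        return type(exc).__name__, type(exc.__cause__).__name__
    return None


typed = asyncio.run(outcome(typed_failure()))
untyped = asyncio.run(outcome(untyped_failure()))
print(typed)    # ('InvalidArgument', 'NoneType')
print(untyped)  # ('Unknown', 'ValueError')
```

Using `raise ... from exc` keeps the original exception as `__cause__`, which is what preserves the error chain for the wrapped case.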
## File Index
```
common/backend/
    __init__.py            # Re-exports: LLMEngine, EngineConfig,
                           # GenerateRequest, GenerateChunk,
                           # Worker, WorkerConfig
    engine.py              # LLMEngine ABC + EngineConfig dataclass
                           # + GenerateRequest / GenerateChunk TypedDicts
    worker.py              # Worker + WorkerConfig
    run.py                 # Common entry point: run(engine_cls)
    sample_engine.py       # SampleLLMEngine (reference impl)
    sample_main.py         # Entry point for sample engine
vllm/llm_engine.py         # VllmLLMEngine
vllm/unified_main.py       # Entry point -> run(VllmLLMEngine)
sglang/llm_engine.py       # SglangLLMEngine
sglang/unified_main.py     # Entry point -> run(SglangLLMEngine)
trtllm/llm_engine.py       # TrtllmLLMEngine
trtllm/unified_main.py     # Entry point -> run(TrtllmLLMEngine)
```
## Feature Gaps
The unified path currently supports **minimal aggregated inference** only.
Below is a summary of what the existing (non-unified) backends provide that
the unified path does not yet support.
### What works today
- Basic aggregated token-in-token-out inference (all three engines)
- Model registration with endpoint types
- Request cancellation via `abort()` + `context.is_stopped()` monitoring
- `DynamoException` error chain wrapping
- Graceful shutdown with signal handling
- Finish reason normalization handled by Rust layer
### Common gaps (all engines)
| Feature | Description |
|---------|-------------|
| Disaggregated serving | Prefill/decode worker split, bootstrap coordination, KV transfer |
| Metrics & Prometheus | Engine-level metrics, KV cache utilization gauges, Prometheus multiprocess registry |
| KV event publishing | Prefix cache events (BlockStored/Removed) to router via ZMQ or NATS |
| Health check payloads | Per-engine custom health check payloads (BOS token probe, etc.) |
| Logprobs | Selected token + top-k log probability extraction and streaming |
| Guided decoding / structured outputs | JSON schema, regex, grammar, choice constraints |
| OpenTelemetry tracing | `build_trace_headers()`, request performance metrics, OTEL propagation |
| Engine routes | Profiling (start/stop), memory release/resume, weight update (disk/tensor/distributed/IPC) |
| Data-parallel routing | DP rank extraction from routing hints, DP-aware scheduling |
| Text-in-text-out mode | OpenAI-compatible chat/completion with engine-side tokenization |
| Custom Jinja chat templates | `--custom-jinja-template` for model-specific prompt formatting |
| Snapshot/checkpoint | CRIU-based engine state save/restore, identity reloading |
### vLLM-specific gaps
| Feature | Description |
|---------|-------------|
| LoRA adapters | Dynamic load/unload/list, ModelDeploymentCard publishing, per-LoRA serialization locks |
| Multimodal (images/video) | Image/video loading, embedding caching, NIXL RDMA transfer, Qwen VL mRoPE |
| Separate encode worker | `EncodeWorkerHandler` for multimodal encode-only disaggregation |
| Sleep/wake/quiesce | 3-level engine lifecycle control (weights, buffers, everything) |
| Elastic EP scaling | `scale_elastic_ep` with Ray node management |
| GMS shadow mode | GPU Memory Service integration with failover lock |
| ModelExpress P2P | Distributed model loading via P2P |
| KV block clearing | Prefix cache reset endpoint |
### SGLang-specific gaps
| Feature | Description |
|---------|-------------|
| Embedding inference | `async_encode()` path, OpenAI embedding response format |
| Image diffusion | `DiffGenerator` for text-to-image (FLUX, etc.) with TP/DP |
| Video generation | `DiffGenerator` for text-to-video (Wan2.1, etc.) |
| LLM diffusion (DLLM) | Diffusion language model algorithm support |
| Multimodal encode worker | Front-facing `MMEncoder`, embedding LRU cache, NIXL transfer |
| Multimodal worker | Aggregated/disaggregated multimodal inference with `EmbeddingsProcessor` |
| Deferred signal handling | Capturing SGLang's internal signal registrations for coordinated shutdown |
| Output modalities override | Required for diffusion workers (default `["text"]` -> `["image"]`/`["video"]`) |
### TRT-LLM-specific gaps
| Feature | Description |
|---------|-------------|
| Custom logits processors | `TrtllmDynamoLogitsAdapter` with CUDA stream support |
| Attention DP scheduling | `SchedulingParams` with `attention_dp_rank` and `attention_dp_relax` |
| Video diffusion | Auto-detect pipeline from `model_index.json`, MP4 encoding, MediaOutput |
| Multimodal processing | `MultimodalRequestProcessor`, image URL processing, embedding injection |
| Encode helper (EPD) | Remote encode via `encode_client`, NIXL tensor reading |
| KV cache connector | KVBM connector config, consolidator ZMQ integration |
| Fatal vs per-request errors | Distinguishing `RequestError` (recoverable) from fatal engine errors |
### Recommended migration order
1. **Metrics & health checks** -- needed for production observability
2. **Disaggregated serving** -- largest architectural change, unlocks PD split
3. **KV event publishing** -- required for KV-aware routing
4. **Logprobs + guided decoding** -- most-requested inference features
5. **Multimodal / LoRA / diffusion** -- modality-specific, can be parallelized across leads
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .engine import EngineConfig, GenerateChunk, GenerateRequest, LLMEngine
from .worker import Worker, WorkerConfig
__all__ = [
"EngineConfig",
"GenerateChunk",
"GenerateRequest",
"LLMEngine",
"Worker",
"WorkerConfig",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Required, TypedDict

from dynamo._core import Context

if TYPE_CHECKING:
    from .worker import WorkerConfig

# ---------------------------------------------------------------------------
# Request / response contracts for generate()
#
# These TypedDicts document the shared fields that all engines read/write.
# Engine-specific keys (output_options, guided_decoding internals, etc.)
# flow through naturally: TypedDict doesn't reject extra keys at runtime.
# ---------------------------------------------------------------------------


class GenerateRequest(TypedDict, total=False):
    """Inbound request dict passed to ``LLMEngine.generate()``.

    ``token_ids`` is always present (set by the Rust preprocessor).
    The remaining groups are optional; engines should access them
    defensively with ``.get(key, {})``.
    """

    token_ids: Required[list[int]]
    sampling_options: dict[str, Any]
    stop_conditions: dict[str, Any]
    output_options: dict[str, Any]


class GenerateChunk(TypedDict, total=False):
    """Single chunk yielded by ``LLMEngine.generate()``.

    Every chunk must include ``token_ids``.
    The final chunk must additionally include ``finish_reason`` and
    ``completion_usage``.
    """

    token_ids: Required[list[int]]
    finish_reason: str
    completion_usage: dict[str, int]


@dataclass
class EngineConfig:
    model: str
    served_model_name: Optional[str] = None
    context_length: Optional[int] = None
    kv_cache_block_size: Optional[int] = None
    total_kv_blocks: Optional[int] = None
    max_num_seqs: Optional[int] = None
    max_num_batched_tokens: Optional[int] = None


class LLMEngine(ABC):
    """Abstract base for inference engines.

    Lifecycle:
        1. from_args(argv) -- parse CLI args, return (engine, WorkerConfig)
        2. start()         -- start the engine, return EngineConfig metadata.
                              After start() returns, generate() MUST be ready
                              to accept calls. Worker begins serving
                              immediately after start().
        3. generate()      -- called for each request (concurrent calls expected)
        4. abort()         -- called when a request is cancelled (optional, default no-op)
        5. cleanup()       -- called once on shutdown, release all resources
    """

    @classmethod
    @abstractmethod
    async def from_args(
        cls, argv: list[str] | None = None
    ) -> tuple[LLMEngine, WorkerConfig]:
        """Parse CLI args and construct the engine (not yet started).

        Args:
            argv: Command-line arguments. ``None`` means ``sys.argv[1:]``.

        Returns:
            A ``(engine, worker_config)`` pair.
        """
        ...

    @abstractmethod
    async def start(self) -> EngineConfig:
        """Start the engine and return registration metadata.

        After this returns the engine MUST be ready to accept ``generate()``
        calls. ``Worker`` will register the model and begin serving
        immediately.
        """
        ...

    @abstractmethod
    async def generate(
        self, request: GenerateRequest, context: Context
    ) -> AsyncGenerator[GenerateChunk, None]:
        """Yield streaming response chunks for a single request.

        Called concurrently for multiple in-flight requests.

        Each chunk: ``{"token_ids": [...]}``
        Final chunk must include: ``{"token_ids": [...], "finish_reason": "...",
        "completion_usage": {...}}``
        """
        ...
        yield  # type: ignore[misc]

    async def abort(self, context: Context) -> None:
        """Abort an in-flight request (optional, default no-op).

        Called by Worker when the client disconnects or
        the request is cancelled. Override to release engine resources
        (KV cache, scheduler slots, etc.).
        """

    @abstractmethod
    async def cleanup(self) -> None:
        """Release all engine resources. Called once on shutdown."""
        ...
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common entry point for unified backends.

Each backend's ``unified_main.py`` calls :func:`run` with its
``LLMEngine`` subclass. Example::

    from dynamo.common.backend.run import run
    from dynamo.vllm.llm_engine import VllmLLMEngine

    def main():
        run(VllmLLMEngine)
"""

import uvloop

from .engine import LLMEngine
from .worker import Worker


async def _start(engine_cls: type[LLMEngine], argv: list[str] | None = None):
    engine, worker_config = await engine_cls.from_args(argv)
    w = Worker(engine, worker_config)
    await w.run()


def run(engine_cls: type[LLMEngine], argv: list[str] | None = None):
    """Entry point for per-backend unified_main.py files."""
    uvloop.run(_start(engine_cls, argv))
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import argparse
import asyncio
from collections.abc import AsyncGenerator

from dynamo._core import Context

from .engine import EngineConfig, GenerateChunk, GenerateRequest, LLMEngine
from .worker import WorkerConfig


class SampleLLMEngine(LLMEngine):
    """Reference LLMEngine implementation.

    Generates rotating token IDs with configurable per-token latency.
    Useful for testing the Worker lifecycle end-to-end
    and as a template for engine leads implementing real backends.
    """

    def __init__(
        self,
        model_name: str = "sample-model",
        max_tokens: int = 16,
        delay: float = 0.01,
    ):
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.delay = delay

    @classmethod
    async def from_args(
        cls, argv: list[str] | None = None
    ) -> tuple[SampleLLMEngine, WorkerConfig]:
        parser = argparse.ArgumentParser(description="Sample Dynamo backend")
        parser.add_argument("--model-name", default="sample-model")
        parser.add_argument("--namespace", default="dynamo")
        parser.add_argument("--component", default="sample")
        parser.add_argument("--endpoint", default="generate")
        parser.add_argument("--max-tokens", type=int, default=16)
        parser.add_argument("--delay", type=float, default=0.01)
        parser.add_argument("--endpoint-types", default="chat,completions")
        parser.add_argument("--discovery-backend", default="etcd")
        parser.add_argument("--request-plane", default="tcp")
        parser.add_argument("--event-plane", default="nats")
        args = parser.parse_args(argv)

        engine = cls(
            model_name=args.model_name,
            max_tokens=args.max_tokens,
            delay=args.delay,
        )
        worker_config = WorkerConfig(
            namespace=args.namespace,
            component=args.component,
            endpoint=args.endpoint,
            model_name=args.model_name,
            served_model_name=args.model_name,
            endpoint_types=args.endpoint_types,
            discovery_backend=args.discovery_backend,
            request_plane=args.request_plane,
            event_plane=args.event_plane,
        )
        return engine, worker_config

    async def start(self) -> EngineConfig:
        return EngineConfig(
            model=self.model_name,
            served_model_name=self.model_name,
            context_length=2048,
            kv_cache_block_size=16,
            total_kv_blocks=1000,
            max_num_seqs=64,
            max_num_batched_tokens=2048,
        )

    async def generate(
        self, request: GenerateRequest, context: Context
    ) -> AsyncGenerator[GenerateChunk, None]:
        token_ids = request.get("token_ids", [])
        prompt_len = len(token_ids)
        stop_conditions = request.get("stop_conditions", {})
        max_new = stop_conditions.get("max_tokens") or self.max_tokens

        for i in range(max_new):
            if context.is_stopped():
                yield {
                    "token_ids": [],
                    "finish_reason": "cancelled",
                    "completion_usage": {
                        "prompt_tokens": prompt_len,
                        "completion_tokens": i,
                        "total_tokens": prompt_len + i,
                    },
                }
                break
            await asyncio.sleep(self.delay)
            token_id = (i + 1) % 32000
            out: GenerateChunk = {"token_ids": [token_id]}
            if i == max_new - 1:
                out["finish_reason"] = "length"
                out["completion_usage"] = {
                    "prompt_tokens": prompt_len,
                    "completion_tokens": max_new,
                    "total_tokens": prompt_len + max_new,
                }
            yield out

    async def cleanup(self) -> None:
        pass
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Entry point for the sample backend.

Usage:
    python -m dynamo.common.backend.sample_main --model-name test-model
"""

from dynamo.common.backend.run import run
from dynamo.common.backend.sample_engine import SampleLLMEngine


def main():
    run(SampleLLMEngine)


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Optional

from dynamo._core import Context
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.runtime import create_runtime
from dynamo.llm import ModelInput, ModelRuntimeConfig, register_model
from dynamo.llm.exceptions import (
    CannotConnect,
    DynamoException,
    EngineShutdown,
    Unknown,
)
from dynamo.runtime.logging import configure_dynamo_logging

from .engine import GenerateChunk, GenerateRequest, LLMEngine

logger = logging.getLogger(__name__)


@dataclass
class WorkerConfig:
    namespace: str
    component: str = "backend"
    endpoint: str = "generate"
    model_name: str = ""
    served_model_name: Optional[str] = None
    model_input: ModelInput = field(default_factory=lambda: ModelInput.Tokens)
    endpoint_types: str = "chat,completions"
    discovery_backend: str = "etcd"
    request_plane: str = "tcp"
    event_plane: str = "nats"
    use_kv_events: bool = False
    custom_jinja_template: Optional[str] = None
    metrics_labels: list = field(default_factory=list)

    @classmethod
    def from_runtime_config(
        cls,
        runtime_cfg,
        model_name: str,
        served_model_name: Optional[str] = None,
        model_input: Optional[ModelInput] = None,
        **overrides,
    ) -> "WorkerConfig":
        """Build from any object that carries DynamoRuntimeConfig fields.

        Works with vllm.Config, trtllm.Config (inherit DynamoRuntimeConfig
        directly) and sglang DynamoConfig (nested in config.dynamo_args).
        """
        kwargs = {
            "namespace": runtime_cfg.namespace,
            "component": getattr(runtime_cfg, "component", None) or "backend",
            "endpoint": getattr(runtime_cfg, "endpoint", None) or "generate",
            "model_name": model_name,
            "served_model_name": served_model_name,
            "endpoint_types": getattr(
                runtime_cfg, "endpoint_types", "chat,completions"
            ),
            "discovery_backend": runtime_cfg.discovery_backend,
            "request_plane": runtime_cfg.request_plane,
            "event_plane": runtime_cfg.event_plane,
            "use_kv_events": getattr(runtime_cfg, "use_kv_events", False),
            "custom_jinja_template": getattr(
                runtime_cfg, "custom_jinja_template", None
            ),
        }
        if model_input is not None:
            kwargs["model_input"] = model_input
        kwargs.update(overrides)
        return cls(**kwargs)


class Worker:
    def __init__(self, engine: LLMEngine, config: WorkerConfig):
        self.config = config
        self.engine = engine

    async def generate(
        self, request: GenerateRequest, context: Context
    ) -> AsyncGenerator[GenerateChunk, None]:
        async def _monitor_cancel():
            await context.async_killed_or_stopped()
            try:
                await self.engine.abort(context)
            except Exception:
                logger.debug("Error during request abort", exc_info=True)

        cancel_task = asyncio.create_task(_monitor_cancel())
        try:
            async for chunk in self.engine.generate(request, context):
                if context.is_stopped():
                    break
                yield chunk
        except DynamoException:
            raise
        except Exception as exc:
            raise Unknown(f"Engine generate failed: {exc}") from exc
        finally:
            if not cancel_task.done():
                cancel_task.cancel()
                try:
                    await cancel_task
                except asyncio.CancelledError:
                    pass

    async def run(self) -> None:
        configure_dynamo_logging()
        cfg = self.config
        shutdown_event = asyncio.Event()

        try:
            runtime, loop = create_runtime(
                discovery_backend=cfg.discovery_backend,
                request_plane=cfg.request_plane,
                event_plane=cfg.event_plane,
                use_kv_events=cfg.use_kv_events,
            )
        except DynamoException:
            raise
        except Exception as exc:
            raise CannotConnect(f"Failed to create runtime: {exc}") from exc

        endpoint = runtime.endpoint(f"{cfg.namespace}.{cfg.component}.{cfg.endpoint}")
        shutdown_endpoints = [endpoint]
        install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)

        try:
            engine_config = await self.engine.start()
        except DynamoException:
            raise
        except Exception as exc:
            raise EngineShutdown(f"Engine initialization failed: {exc}") from exc

        try:
            runtime_config = ModelRuntimeConfig()
            if engine_config.total_kv_blocks is not None:
                runtime_config.total_kv_blocks = engine_config.total_kv_blocks
            if engine_config.max_num_seqs is not None:
                runtime_config.max_num_seqs = engine_config.max_num_seqs
            if engine_config.max_num_batched_tokens is not None:
                runtime_config.max_num_batched_tokens = (
                    engine_config.max_num_batched_tokens
                )

            model_type = parse_endpoint_types(cfg.endpoint_types)
            served_name = cfg.served_model_name or cfg.model_name
            await register_model(
                cfg.model_input,
                model_type,
                endpoint,
                cfg.model_name,
                served_name,
                context_length=engine_config.context_length,
                kv_cache_block_size=engine_config.kv_cache_block_size,
                runtime_config=runtime_config,
                custom_template_path=cfg.custom_jinja_template,
            )
            logger.info(
                "Serving %s on %s.%s.%s",
                served_name,
                cfg.namespace,
                cfg.component,
                cfg.endpoint,
            )
            await endpoint.serve_endpoint(
                self.generate,
                graceful_shutdown=True,
                metrics_labels=cfg.metrics_labels,
            )
        finally:
            await self.engine.cleanup()
            logger.info("Engine cleanup complete")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""SGLang LLMEngine implementation for the unified backend.

See dynamo/common/backend/README.md for architecture, response contract,
and feature gap details.
"""

from __future__ import annotations

import logging
import sys
from collections.abc import AsyncGenerator

import sglang as sgl

from dynamo._core import Context
from dynamo.common.backend.engine import (
    EngineConfig,
    GenerateChunk,
    GenerateRequest,
    LLMEngine,
)
from dynamo.common.backend.worker import WorkerConfig
from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import ModelInput
from dynamo.sglang.args import parse_args

logger = logging.getLogger(__name__)


class SglangLLMEngine(LLMEngine):
    def __init__(self, server_args):
        self.server_args = server_args
        self.engine = None
        self._input_param_manager = None
        self._skip_tokenizer_init = server_args.skip_tokenizer_init

    @classmethod
    async def from_args(
        cls, argv: list[str] | None = None
    ) -> tuple[SglangLLMEngine, WorkerConfig]:
        config = await parse_args(argv if argv is not None else sys.argv[1:])
        server_args = config.server_args
        dynamo_args = config.dynamo_args

        model_input = (
            ModelInput.Text
            if not server_args.skip_tokenizer_init
            else ModelInput.Tokens
        )

        engine = cls(server_args)
        worker_config = WorkerConfig.from_runtime_config(
            dynamo_args,
            model_name=server_args.model_path,
            served_model_name=server_args.served_model_name,
            model_input=model_input,
        )
        return engine, worker_config

    async def start(self) -> EngineConfig:
        self.engine = sgl.Engine(server_args=self.server_args)

        tokenizer = (
            self.engine.tokenizer_manager.tokenizer
            if not self._skip_tokenizer_init
            else None
        )
        self._input_param_manager = InputParamManager(tokenizer)

        # Capacity fields -- sourced the same way as register.py in the
        # non-unified path so the Rust runtime gets consistent values.
        total_kv_blocks = None
        scheduler_info = getattr(self.engine, "scheduler_info", None) or {}
        max_total_tokens = scheduler_info.get("max_total_num_tokens")
        page_size = self.server_args.page_size
        if max_total_tokens and page_size:
            total_kv_blocks = (max_total_tokens + page_size - 1) // page_size

        return EngineConfig(
            model=self.server_args.model_path,
            served_model_name=self.server_args.served_model_name,
            context_length=self.server_args.context_length,
            kv_cache_block_size=page_size,
            total_kv_blocks=total_kv_blocks,
            max_num_seqs=getattr(self.server_args, "max_running_requests", None),
            max_num_batched_tokens=getattr(
                self.server_args, "max_prefill_tokens", None
            ),
        )

    async def generate(
        self, request: GenerateRequest, context: Context
    ) -> AsyncGenerator[GenerateChunk, None]:
        assert self.engine is not None, "Engine not initialized"

        sampling_params = self._build_sampling_params(request)
        input_param = self._get_input_param(request)

        stream = await self.engine.async_generate(
            **input_param,
            sampling_params=sampling_params,
            stream=True,
            rid=context.trace_id,
        )

        async for res in stream:
            out: GenerateChunk = {"token_ids": []}
            meta_info = res["meta_info"]
            finish_reason = meta_info["finish_reason"]

            output_ids = res.get("output_ids", [])
            if not output_ids and not finish_reason:
                if context.is_stopped():
                    prompt_tokens = meta_info.get("prompt_tokens", 0)
                    completion_tokens = meta_info.get("completion_tokens", 0)
                    yield {
                        "token_ids": [],
                        "finish_reason": "cancelled",
                        "completion_usage": {
                            "prompt_tokens": prompt_tokens,
                            "completion_tokens": completion_tokens,
                            "total_tokens": prompt_tokens + completion_tokens,
                        },
                    }
                    break
                continue

            out["token_ids"] = output_ids

            if finish_reason:
                prompt_tokens = meta_info["prompt_tokens"]
                completion_tokens = meta_info["completion_tokens"]
                out["finish_reason"] = finish_reason["type"]
                out["completion_usage"] = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }

            if context.is_stopped():
                prompt_tokens = meta_info.get("prompt_tokens", 0)
                completion_tokens = meta_info.get("completion_tokens", 0)
                yield {
                    "token_ids": output_ids,
                    "finish_reason": "cancelled",
                    "completion_usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                }
                break

            yield out

    async def abort(self, context: Context) -> None:
        rid = context.trace_id
        if self.engine is not None and rid is not None:
            if (
                hasattr(self.engine, "tokenizer_manager")
                and self.engine.tokenizer_manager
            ):
                self.engine.tokenizer_manager.abort_request(rid=rid, abort_all=False)
                logger.debug("Aborted request %s", rid)

    async def cleanup(self) -> None:
        if self.engine is not None:
            self.engine.shutdown()
            logger.info("SGLang engine shutdown")

    def _build_sampling_params(self, request: GenerateRequest) -> dict:
        if self._skip_tokenizer_init:
            sampling_opts = request.get("sampling_options", {})
            stop_conditions = request.get("stop_conditions", {})
            param_mapping = {
                "temperature": sampling_opts.get("temperature"),
                "top_p": sampling_opts.get("top_p"),
                "top_k": sampling_opts.get("top_k"),
                "max_new_tokens": stop_conditions.get("max_tokens"),
                "ignore_eos": stop_conditions.get("ignore_eos"),
            }
        else:
            param_mapping = {
                "temperature": request.get("temperature"),
                "top_p": request.get("top_p"),
                "top_k": request.get("top_k"),
                "max_new_tokens": request.get("max_tokens"),
            }
        return {k: v for k, v in param_mapping.items() if v is not None}

    def _get_input_param(self, request: GenerateRequest) -> dict:
        assert self._input_param_manager is not None, "Engine not initialized"
        request_input = self._input_param_manager.get_input_param(
            request, use_tokenizer=not self._skip_tokenizer_init
        )
        return {
            "prompt" if isinstance(request_input, str) else "input_ids": request_input
        }
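
The `total_kv_blocks` value in `start()` is a ceiling division of the token capacity by the page size, and it stays `None` when either input is missing or zero. The arithmetic can be checked in isolation with a small sketch; the helper name `kv_blocks_for` is hypothetical and not part of the SGLang or Dynamo APIs.

```python
def kv_blocks_for(max_total_tokens, page_size):
    """Ceiling-divide token capacity by page size; None if either is unset."""
    if not max_total_tokens or not page_size:
        return None
    # Adding (page_size - 1) before floor division rounds partial pages up.
    return (max_total_tokens + page_size - 1) // page_size


exact = kv_blocks_for(1024, 16)    # 1024 tokens fill 64 pages exactly
partial = kv_blocks_for(1000, 16)  # 62.5 pages rounds up to 63
missing = kv_blocks_for(None, 16)  # no scheduler info reported
```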
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unified entry point for the SGLang backend.

Usage:
    python -m dynamo.sglang.unified_main <sglang args>

See dynamo/common/backend/README.md for architecture, response contract,
and feature gap details.
"""

from dynamo.common.backend.run import run
from dynamo.sglang.llm_engine import SglangLLMEngine


def main():
    run(SglangLLMEngine)


if __name__ == "__main__":
    main()
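
The entry point hands the engine class to `run`, which drives the lifecycle described in the backend module notes: `from_args()`, then `start()`, then `generate()`, then `cleanup()`. The sketch below walks a toy engine through that order. `EchoEngine` and `drive` are hypothetical, and the real `run` also registers the model and serves an endpoint, so this only illustrates the ordering and the guarantee that cleanup runs.

```python
import asyncio


class EchoEngine:
    """Hypothetical toy engine that echoes prompt tokens back one by one."""

    def __init__(self):
        self.events = []

    @classmethod
    async def from_args(cls, argv=None):
        engine = cls()
        engine.events.append("from_args")
        return engine, {"model_name": "echo"}  # stand-in for a WorkerConfig

    async def start(self):
        self.events.append("start")
        return {"context_length": 16}  # stand-in for an EngineConfig

    async def generate(self, request, context=None):
        token_ids = request["token_ids"]
        for token_id in token_ids:
            yield {"token_ids": [token_id]}
        # The final chunk carries the finish reason and usage counts.
        yield {
            "token_ids": [],
            "finish_reason": "stop",
            "completion_usage": {
                "prompt_tokens": len(token_ids),
                "completion_tokens": len(token_ids),
                "total_tokens": 2 * len(token_ids),
            },
        }

    async def cleanup(self):
        self.events.append("cleanup")


async def drive(engine_cls, request):
    engine, _worker_config = await engine_cls.from_args([])
    await engine.start()
    chunks = []
    try:
        async for chunk in engine.generate(request):
            chunks.append(chunk)
    finally:
        # Mirrors the worker: cleanup runs even if serving raises.
        await engine.cleanup()
    return engine, chunks


engine, chunks = asyncio.run(drive(EchoEngine, {"token_ids": [1, 2, 3]}))
```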
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""TensorRT-LLM LLMEngine implementation for the unified backend.

See dynamo/common/backend/README.md for architecture, response contract,
and feature gap details.
"""

from __future__ import annotations

import dataclasses
import logging
import re
from collections.abc import AsyncGenerator
from typing import Any

from tensorrt_llm.llmapi import KvCacheConfig, SchedulerConfig
from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams
from torch.cuda import device_count

from dynamo._core import Context
from dynamo.common.backend.engine import (
    EngineConfig,
    GenerateChunk,
    GenerateRequest,
    LLMEngine,
)
from dynamo.common.backend.worker import WorkerConfig
from dynamo.llm import ModelInput
from dynamo.trtllm.args import parse_args
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine

logger = logging.getLogger(__name__)


class TrtllmLLMEngine(LLMEngine):
    def __init__(
        self,
        engine_args: dict[str, Any],
        model_name: str,
        served_model_name: str | None = None,
        max_seq_len: int | None = None,
        max_batch_size: int | None = None,
        max_num_tokens: int | None = None,
        kv_block_size: int = 32,
    ):
        self.engine_args = engine_args
        self.model_name = model_name
        self.served_model_name = served_model_name
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.max_num_tokens = max_num_tokens
        self.kv_block_size = kv_block_size
        self._engine: TensorRTLLMEngine | None = None
        self._default_sampling_params = SamplingParams(detokenize=False)
        self._active_requests: dict[str, Any] = {}

    @classmethod
    async def from_args(
        cls, argv: list[str] | None = None
    ) -> tuple[TrtllmLLMEngine, WorkerConfig]:
        config = parse_args(argv)

        gpus_per_node = config.gpus_per_node or device_count()

        engine_args = {
            "model": str(config.model),
            "scheduler_config": SchedulerConfig(),
            "tensor_parallel_size": config.tensor_parallel_size,
            "pipeline_parallel_size": config.pipeline_parallel_size,
            "backend": Backend.PYTORCH,
            "kv_cache_config": KvCacheConfig(
                free_gpu_memory_fraction=config.free_gpu_memory_fraction,
            ),
            "gpus_per_node": gpus_per_node,
            "max_num_tokens": config.max_num_tokens,
            "max_seq_len": config.max_seq_len,
            "max_beam_width": config.max_beam_width,
            "max_batch_size": config.max_batch_size,
        }

        engine = cls(
            engine_args=engine_args,
            model_name=config.model,
            served_model_name=config.served_model_name,
            max_seq_len=config.max_seq_len,
            max_batch_size=config.max_batch_size,
            max_num_tokens=config.max_num_tokens,
            kv_block_size=config.kv_block_size,
        )
        worker_config = WorkerConfig.from_runtime_config(
            config,
            model_name=config.model,
            served_model_name=config.served_model_name,
            model_input=ModelInput.Tokens,
        )
        return engine, worker_config

    async def start(self) -> EngineConfig:
        self._engine = TensorRTLLMEngine(self.engine_args)
        await self._engine.initialize()

        return EngineConfig(
            model=self.model_name,
            served_model_name=self.served_model_name,
            context_length=self.max_seq_len,
            kv_cache_block_size=self.kv_block_size,
            max_num_seqs=self.max_batch_size,
            max_num_batched_tokens=self.max_num_tokens,
        )

    async def generate(
        self, request: GenerateRequest, context: Context
    ) -> AsyncGenerator[GenerateChunk, None]:
        assert self._engine is not None, "Engine not initialized"

        token_ids = request.get("token_ids", [])
        sampling_params = self._override_sampling_params(
            self._default_sampling_params, request
        )

        stop_conditions = request.get("stop_conditions", {})
        max_tokens = stop_conditions.get("max_tokens")
        if max_tokens is not None:
            sampling_params.max_tokens = max_tokens
        elif self.max_seq_len is not None:
            sampling_params.max_tokens = max(1, self.max_seq_len - len(token_ids))

        ignore_eos = stop_conditions.get("ignore_eos")
        if ignore_eos:
            sampling_params.ignore_eos = ignore_eos

        generation_result = self._engine.llm.generate_async(
            inputs=token_ids,
            sampling_params=sampling_params,
            streaming=True,
        )

        request_id = context.id()
        if request_id is not None:
            self._active_requests[request_id] = generation_result

        try:
            num_output_tokens_so_far = 0
            async for res in generation_result:
                if not res.outputs and not res.finished:
                    yield {"finish_reason": "error", "token_ids": []}
                    break

                output = res.outputs[0]
                next_total = len(output.token_ids)
                out: GenerateChunk = {
                    "token_ids": output.token_ids[num_output_tokens_so_far:]
                }

                if output.finish_reason:
                    out["finish_reason"] = str(output.finish_reason)

                if out.get("finish_reason") or res.finished:
                    if not out.get("finish_reason"):
                        out["finish_reason"] = "unknown"
                    prompt_tokens = len(token_ids)
                    out["completion_usage"] = {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": next_total,
                        "total_tokens": prompt_tokens + next_total,
                    }

                yield out
                num_output_tokens_so_far = next_total
        finally:
            if request_id is not None:
                self._active_requests.pop(request_id, None)

    async def abort(self, context: Context) -> None:
        request_id = context.id()
        if request_id is not None:
            generation_result = self._active_requests.get(request_id)
            if generation_result is not None:
                generation_result.abort()
                logger.debug("Aborted request %s", request_id)

    async def cleanup(self) -> None:
        if self._engine is not None:
            await self._engine.cleanup()
            logger.info("TensorRT-LLM engine shutdown")

    @staticmethod
    def _override_sampling_params(
        sampling_params: SamplingParams, request: GenerateRequest
    ) -> SamplingParams:
        overrides = {
            key: value
            for key, value in request.get("sampling_options", {}).items()
            if value is not None
        }

        guided_decoding = overrides.pop("guided_decoding", None)
        if guided_decoding is not None and isinstance(guided_decoding, dict):
            regex = guided_decoding.get("regex")
            choice = guided_decoding.get("choice")
            if choice and not regex:
                valid_choices = [c for c in choice if c is not None]
                if valid_choices:
                    regex = "(" + "|".join(re.escape(c) for c in valid_choices) + ")"
            overrides["guided_decoding"] = GuidedDecodingParams(
                json=guided_decoding.get("json"),
                regex=regex,
                grammar=guided_decoding.get("grammar"),
                json_object=guided_decoding.get("json_object", False),
                structural_tag=guided_decoding.get("structural_tag"),
            )

        return dataclasses.replace(sampling_params, **overrides)
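
When a request supplies `choice` but no `regex`, `_override_sampling_params` builds an alternation of the escaped choices. A standalone sketch of that conversion is shown below; the helper name `choices_to_regex` is hypothetical. Escaping matters because a choice such as `a.b` would otherwise match `aXb` as well.

```python
import re


def choices_to_regex(choices):
    """Turn a list of literal choices into one alternation regex, or None."""
    valid = [c for c in choices if c is not None]
    if not valid:
        return None
    # re.escape keeps regex metacharacters in a choice from being interpreted.
    return "(" + "|".join(re.escape(c) for c in valid) + ")"


simple = choices_to_regex(["yes", "no", None])
dotted = choices_to_regex(["a.b"])
empty = choices_to_regex([None])
```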
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unified entry point for the TensorRT-LLM backend.

Usage:
    python -m dynamo.trtllm.unified_main <trtllm args>

See dynamo/common/backend/README.md for architecture, response contract,
and feature gap details.
"""

from dynamo.common.backend.run import run
from dynamo.trtllm.llm_engine import TrtllmLLMEngine


def main():
    run(TrtllmLLMEngine)


if __name__ == "__main__":
    main()
@@ -63,9 +63,12 @@ def _preprocess_for_encode_config(config: Config) -> Dict[str, Any]:
     return config.__dict__


-def parse_args() -> Config:
+def parse_args(argv: list[str] | None = None) -> Config:
     """Parse command-line arguments for the vLLM backend.

+    Args:
+        argv: Command-line arguments. ``None`` means ``sys.argv[1:]``.
+
     Returns:
         Config: Parsed configuration object.
     """
@@ -94,7 +97,7 @@ def parse_args() -> Config:
             continue
         vg._group_actions.append(action)

-    args, unknown = parser.parse_known_args()
+    args, unknown = parser.parse_known_args(argv)
     dynamo_config = Config.from_cli_args(args)

     # Validate arguments
......
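
Passing `argv` through to `parse_known_args` lets a caller such as `from_args(argv)` supply its own argument list, while `None` keeps the standard `argparse` behaviour of reading `sys.argv[1:]`. A minimal sketch with a toy parser follows; the `--model` flag, its default, and `--extra-flag` are assumptions for illustration, not the vLLM backend's real options.

```python
import argparse


def parse_args(argv=None):
    """Parse known flags from argv; None falls back to sys.argv[1:]."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="sample-model")  # hypothetical flag
    # parse_known_args returns (namespace, leftover arguments).
    args, unknown = parser.parse_known_args(argv)
    return args, unknown


args, unknown = parse_args(["--model", "Qwen/Qwen3-0.6B", "--extra-flag"])
defaults, leftover = parse_args([])
```

An explicit list also makes the parser easy to exercise in tests without patching `sys.argv`.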
@@ -191,7 +191,8 @@ def build_sampling_params(
     sampling_params.detokenize = False

     # Handle guided_decoding - convert to StructuredOutputsParams
-    guided_decoding = request["sampling_options"].get("guided_decoding")
+    sampling_options = request.get("sampling_options", {})
+    guided_decoding = sampling_options.get("guided_decoding")
     if guided_decoding is not None and isinstance(guided_decoding, dict):
         sampling_params.structured_outputs = StructuredOutputsParams(
             json=guided_decoding.get("json"),
@@ -202,7 +203,7 @@ def build_sampling_params(
         )

     # Apply remaining sampling_options
-    for key, value in request["sampling_options"].items():
+    for key, value in sampling_options.items():
         # Skip guided_decoding - already handled above
         if key == "guided_decoding":
             continue
@@ -210,7 +211,7 @@ def build_sampling_params(
             setattr(sampling_params, key, value)

     # Apply stop_conditions
-    for key, value in request["stop_conditions"].items():
+    for key, value in request.get("stop_conditions", {}).items():
         if value is not None and hasattr(sampling_params, key):
             # Do not add stop key to sampling params - dynamo handles stop conditions directly
             if key == "stop":
......
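
The hunks above replace subscript access with `request.get(..., {})`, so a request that omits `sampling_options` or `stop_conditions` no longer raises `KeyError`. The difference can be seen in a small sketch; `read_options_strict` and `read_options_tolerant` are hypothetical names, and the request dict is a made-up minimal example.

```python
def read_options_strict(request):
    """Old style: both keys must be present."""
    return request["sampling_options"], request["stop_conditions"]


def read_options_tolerant(request):
    """New style: a missing key behaves like an empty mapping."""
    return request.get("sampling_options", {}), request.get("stop_conditions", {})


minimal_request = {"token_ids": [1, 2, 3]}

try:
    read_options_strict(minimal_request)
    strict_failed = False
except KeyError:
    strict_failed = True

sampling_options, stop_conditions = read_options_tolerant(minimal_request)
```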
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""vLLM LLMEngine implementation for the unified backend.

See dynamo/common/backend/README.md for architecture, response contract,
and feature gap details.
"""

from __future__ import annotations

import logging
import os
import tempfile
from collections.abc import AsyncGenerator

from vllm.inputs import TokensPrompt
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM

from dynamo._core import Context
from dynamo.common.backend.engine import (
    EngineConfig,
    GenerateChunk,
    GenerateRequest,
    LLMEngine,
)
from dynamo.common.backend.worker import WorkerConfig
from dynamo.llm import ModelInput
from dynamo.vllm.args import parse_args

from .handlers import build_sampling_params

logger = logging.getLogger(__name__)


class VllmLLMEngine(LLMEngine):
    def __init__(self, engine_args):
        self.engine_args = engine_args
        self.engine_client = None
        self._vllm_config = None
        self._default_sampling_params = None
        self._prometheus_temp_dir = None
        self._model_max_len = None

    @classmethod
    async def from_args(
        cls, argv: list[str] | None = None
    ) -> tuple[VllmLLMEngine, WorkerConfig]:
        config = parse_args(argv)

        if not config.served_model_name:
            config.served_model_name = (
                config.engine_args.served_model_name
            ) = config.model

        engine = cls(config.engine_args)
        worker_config = WorkerConfig.from_runtime_config(
            config,
            model_name=config.model,
            served_model_name=config.served_model_name,
            model_input=ModelInput.Tokens,
        )
        return engine, worker_config

    async def start(self) -> EngineConfig:
        os.environ["VLLM_NO_USAGE_STATS"] = "1"
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            self._prometheus_temp_dir = tempfile.TemporaryDirectory(
                prefix="vllm_prometheus_"
            )
            os.environ["PROMETHEUS_MULTIPROC_DIR"] = self._prometheus_temp_dir.name

        self._default_sampling_params = (
            self.engine_args.create_model_config().get_diff_sampling_param()
        )

        vllm_config = self.engine_args.create_engine_config(
            usage_context=UsageContext.OPENAI_API_SERVER
        )
        self._vllm_config = vllm_config

        self.engine_client = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=UsageContext.OPENAI_API_SERVER,
        )

        self._model_max_len = getattr(
            getattr(vllm_config, "model_config", None), "max_model_len", None
        )

        num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks or 0
        block_size = vllm_config.cache_config.block_size

        return EngineConfig(
            model=self.engine_args.model,
            served_model_name=self.engine_args.served_model_name,
            context_length=self._model_max_len,
            kv_cache_block_size=block_size,
            total_kv_blocks=num_gpu_blocks,
            max_num_seqs=vllm_config.scheduler_config.max_num_seqs,
            max_num_batched_tokens=vllm_config.scheduler_config.max_num_batched_tokens,
        )

    async def generate(
        self, request: GenerateRequest, context: Context
    ) -> AsyncGenerator[GenerateChunk, None]:
        assert self.engine_client is not None, "Engine not initialized"
        assert self._default_sampling_params is not None, "Engine not initialized"

        request_id = context.id()
        token_ids = request.get("token_ids", [])
        prompt = TokensPrompt(prompt_token_ids=token_ids)

        # TODO: remove dict() once build_sampling_params accepts GenerateRequest
        sampling_params = build_sampling_params(
            dict(request), self._default_sampling_params, self._model_max_len
        )

        gen = self.engine_client.generate(prompt, sampling_params, request_id)

        num_output_tokens_so_far = 0
        async for res in gen:
            if not res.outputs:
                yield {
                    "finish_reason": "error: No outputs from vLLM engine",
                    "token_ids": [],
                }
                break

            output = res.outputs[0]
            next_total = len(output.token_ids)
            out: GenerateChunk = {
                "token_ids": output.token_ids[num_output_tokens_so_far:]
            }

            if output.finish_reason:
                out["finish_reason"] = str(output.finish_reason)
                prompt_tokens = len(res.prompt_token_ids) if res.prompt_token_ids else 0
                out["completion_usage"] = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": next_total,
                    "total_tokens": prompt_tokens + next_total,
                }

            yield out
            num_output_tokens_so_far = next_total

    async def abort(self, context: Context) -> None:
        request_id = context.id()
        if self.engine_client is not None and request_id is not None:
            await self.engine_client.abort(request_id)
            logger.debug("Aborted request %s", request_id)

    async def cleanup(self) -> None:
        if self.engine_client is not None:
            self.engine_client.shutdown()
        if self._prometheus_temp_dir is not None:
            self._prometheus_temp_dir.cleanup()
        logger.info("vLLM engine shutdown")
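
Both the vLLM and TensorRT-LLM `generate()` loops slice `output.token_ids[num_output_tokens_so_far:]`, which implies each engine result carries the cumulative output so far and only the new tail is forwarded. A minimal sketch of that cumulative-to-delta conversion, under that assumption, is below; `to_deltas` is a hypothetical helper name.

```python
def to_deltas(cumulative_snapshots):
    """Yield only the tokens added since the previous snapshot."""
    emitted = 0
    for token_ids in cumulative_snapshots:
        # Everything before `emitted` was already sent in an earlier chunk.
        yield token_ids[emitted:]
        emitted = len(token_ids)


snapshots = [[11], [11, 12], [11, 12, 13, 14]]
deltas = list(to_deltas(snapshots))
```

Concatenating the deltas reproduces the final snapshot, which is why the last `next_total` can be reported as `completion_tokens`.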
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unified entry point for the vLLM backend.

Usage:
    python -m dynamo.vllm.unified_main <vllm args>

See dynamo/common/backend/README.md for architecture, response contract,
and feature gap details.
"""

from dynamo.common.backend.run import run
from dynamo.vllm.llm_engine import VllmLLMEngine


def main():
    run(VllmLLMEngine)


if __name__ == "__main__":
    main()
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Aggregated serving with the sample (echo) backend.
# GPUs: 0 (CPU-only)

set -e
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/launch_utils.sh" # print_launch_banner, wait_any_exit

# Default values
MODEL_NAME="${MODEL_NAME:-sample-model}"

EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        --model-name)
            MODEL_NAME="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo " --model-name <name> Specify model name (default: $MODEL_NAME)"
            echo " -h, --help Show this help message"
            echo ""
            echo "Any additional options are passed through to sample_main."
            exit 0
            ;;
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

HTTP_PORT="${DYN_HTTP_PORT:-8000}"
print_launch_banner "Launching Sample Aggregated Serving" "$MODEL_NAME" "$HTTP_PORT"

# run frontend
python3 -m dynamo.frontend &

# run sample worker
python3 -m dynamo.common.backend.sample_main \
    --model-name "$MODEL_NAME" \
    "${EXTRA_ARGS[@]}" &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
@@ -15,6 +15,7 @@ source "$SCRIPT_DIR/../../../common/launch_utils.sh" # print_launch_banner, wait
 # Default values
 MODEL="Qwen/Qwen3-0.6B"
 ENABLE_OTEL=false
+USE_UNIFIED=false

 # Parse command line arguments
 EXTRA_ARGS=()
@@ -28,11 +29,16 @@ while [[ $# -gt 0 ]]; do
             ENABLE_OTEL=true
             shift
             ;;
+        --unified)
+            USE_UNIFIED=true
+            shift
+            ;;
         -h|--help)
             echo "Usage: $0 [OPTIONS]"
             echo "Options:"
             echo " --model-path <name> Specify model (default: $MODEL)"
             echo " --enable-otel Enable OpenTelemetry tracing"
+            echo " --unified Use unified_main entry point (Worker)"
             echo " -h, --help Show this help message"
             echo ""
             echo "Additional SGLang/Dynamo flags can be passed and will be forwarded"
@@ -66,8 +72,12 @@ OTEL_SERVICE_NAME=dynamo-frontend \
 python3 -m dynamo.frontend &

 # run worker with metrics enabled
+WORKER_MODULE="dynamo.sglang"
+if [ "$USE_UNIFIED" = true ]; then
+    WORKER_MODULE="dynamo.sglang.unified_main"
+fi
 OTEL_SERVICE_NAME=dynamo-worker DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
-python3 -m dynamo.sglang \
+python3 -m "$WORKER_MODULE" \
   --model-path "$MODEL" \
   --served-model-name "$MODEL" \
   --page-size 16 \
......