Unverified Commit f00d700e authored by Alec's avatar Alec Committed by GitHub
Browse files

refactor: remove old examples with old UX (#1899)

parent c7080419
../../docs/examples/hello_world.md
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
message: "earth"
port: 8000
ServiceArgs:
workers: 1
resources:
cpu: "1"
Middle:
message: "moon"
ServiceArgs:
workers: 2
resources:
cpu: "1"
Backend:
message: "mars"
ServiceArgs:
workers: 1
resources:
cpu: "1"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: hello-world
spec:
services:
Frontend:
dynamoNamespace: hello-world
componentType: main
replicas: 1
resources:
requests:
cpu: "1"
memory: "2Gi"
limits:
cpu: "1"
memory: "2Gi"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.3.1
workingDir: /workspace/examples/hello_world
args:
- dynamo
- serve
- hello_world:Frontend
- --system-app-port
- "5000"
- --enable-system-app
- --use-default-health-checks
- --service-name
- Frontend
- -f
- ./config.yaml
Middle:
dynamoNamespace: hello-world
replicas: 1
resources:
requests:
cpu: "1"
memory: "2Gi"
limits:
cpu: "1"
memory: "2Gi"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.3.1
workingDir: /workspace/examples/hello_world
args:
- dynamo
- serve
- hello_world:Middle
- --system-app-port
- "5000"
- --enable-system-app
- --use-default-health-checks
- --service-name
- Middle
- -f
- ./config.yaml
Backend:
dynamoNamespace: hello-world
replicas: 1
resources:
requests:
cpu: "1"
memory: "2Gi"
limits:
cpu: "1"
memory: "2Gi"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.3.1
workingDir: /workspace/examples/hello_world
args:
- dynamo
- serve
- hello_world:Backend
- --system-app-port
- "5000"
- --enable-system-app
- --use-default-health-checks
- --service-name
- Backend
- -f
- ./config.yaml
# Deployment Examples
This directory contains a hello world example which implements a simplified disaggregated serving architecture used for deploying Large Language Models (LLMs). It removes the LLM related inference code and focuses on how Dynamo handles routing, task queue and metadata communication between prefill and decode workers.
## Components
- frontend: A simple http server handles incoming requests
- processor: A pre/post processing server and invokes router server
- router: Handles API requests and routes them to appropriate workers based on specified strategy
- worker: A dummy decode worker
- prefill-worker: A dummy prefill worker
## Deployment Architectures
This figure shows an overview of the major components to deploy:
```
+----------------+
| prefill worker |-------+
| | |
+----------------+ | pull
v
+------+ +-----------+ +------------------+ push +---------------+
| HTTP |----->| processor |----->| decode/monolith |------------>| prefill queue |
| |<-----| |<-----| worker | | |
+------+ +-----------+ +------------------+ +---------------+
| ^
query best | | return
worker | | worker_id
| | +------------------+
| +---------| router |
+------------->| |
+------------------+
```
## The Aggregated Deployment
In this example, we will use 2 nodes to demo the disagg serving.
- Node 1
- Runs NATS and etcd services
- Deploys Frontend, Processor and Router
- Deploys DummyWorker as the monolith worker
- Node 2
- Deploys DummyWorker as the monolith worker
### Prerequisites
On Node 1, start required services (etcd and NATS) using [Docker Compose](../../../deploy/metrics/docker-compose.yml)
```bash
docker compose -f deploy/metrics/docker-compose.yml up -d
```
### Run the Deployment
1. Set environment variables for NATS and etcd services
```bash
export NATS_SERVER="nats://Node_1_IP_ADDRESS:4222"
export ETCD_ENDPOINTS="http://Node_1_IP_ADDRESS:2379"
```
2. Launch Frontend, Processor and Router services:
```
cd dynamo/examples/hello_world/disagg_skeleton
dynamo serve components.graph:Frontend
```
3. Open a new terminal on Node 1 and deploy Worker service
```
export NATS_SERVER="nats://Node_1_IP_ADDRESS:4222"
export ETCD_ENDPOINTS="http://Node_1_IP_ADDRESS:2379"
cd dynamo/examples/hello_world/disagg_skeleton
dynamo serve components.worker:DummyWorker
```
4. Go to Node 2 and start Worker service as in step 3.
Now you should see both workers are ready in Node 1's terminal.
5. Query the Frontend with following two prompts. The router would assign different workers for each prompt and you can observe it from the responses.
- `Response: {"worker_output":"Tell me a joke_GeneratedBy_NODE1HOSTNAME","request_id":"id_number"}`
- `Response: {"worker_output":"Which team won 2020 World Series_GeneratedBy_NODE2HOSTNAME","request_id":"id_number"}`
```
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "Tell me a joke",
"request_id":"id_number"
}'
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "Which team won 2020 World Series",
"request_id":"id_number"
}'
```
6. Then modify the prompt and you will notice prompts with similar prefix will be routed to the same worker due to the simply routing algorithm used in this demo. For example, following query will be routed to the worker proceesed "Tell me a joke" prompt.
```
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "Tell me a fact",
"request_id":"id_number"
}'
```
-`Response: {"worker_output":"Tell me a fact_GeneratedBy_NODE1HOSTNAME","request_id":"id_number"}`
## The Disaggregated Deployment
In this example, we will use 3 nodes to demo the disagg serving.
- Node 1
- Runs NATS and etcd services
- Deploys Frontend and Processor
- Deploys DummyWorker as the decode worker
- Node 2
- Deploys DummyWorker as the decode worker
- Node 3
- Deploys Prefill as the prefill worker
### Run the Deployment
1. Repeat step 1 to 4 to deploy Frontend, Processor, Router and 2 Workers as decode worker
2. Go to Node 3 and start the prefill worker.
```
export NATS_SERVER="nats://Node_1_IP_ADDRESS:4222"
export ETCD_ENDPOINTS="http://Node_1_IP_ADDRESS:2379"
cd dynamo/examples/hello_world/disagg_skeleton
dynamo serve components.prefill_worker:PrefillWorker
```
3. Query the Frontend. This time decode workers push requests to the prefill queue, and prefill worker pulles task from the queue to simulate the prefill task. The actual prefill is skipped in this demo.
```
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "This is prefill disagg serving example",
"request_id":"12345"
}'
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
import sys
from components.processor import Processor
from components.utils import GeneralRequest
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
logger = logging.getLogger(__name__)
app = FastAPI(title="Hello World LLM")
@service(
dynamo={"namespace": "dynamo-demo"},
image=DYNAMO_IMAGE,
app=app,
)
class Frontend:
processor = depends(Processor)
def __init__(self):
signal.signal(signal.SIGTERM, self.handle_exit)
signal.signal(signal.SIGINT, self.handle_exit)
def handle_exit(self, signum, frame):
logger.debug(f"Received signal {signum}, shutting down...")
sys.exit(0)
@api()
async def generate(self, prompt, request_id): # from request body keys
"""Stream results from the pipeline."""
logger.info(f"Received: {prompt=},{request_id=}")
async def content_generator():
frontend_request = GeneralRequest(
prompt=prompt, request_id=request_id
).model_dump_json()
async for response in self.processor.processor_generate(frontend_request):
yield f"Response: {response}\n"
return StreamingResponse(content_generator())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.frontend import Frontend
from components.kv_router import Router
from components.processor import Processor
Frontend.link(Processor).link(Router)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from difflib import SequenceMatcher
from typing import AsyncIterator
from components.utils import check_required_workers
from components.worker import DummyWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
WorkerId = str
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo-demo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Router:
"""
Request handler for the generate endpoint
"""
kv_cache: dict[str, str] = {}
threshold = 0.6
worker = depends(DummyWorker)
def __init__(self):
self.min_workers = 2
@async_on_start
async def async_init(self):
print("in kv router async_init")
self.runtime = dynamo_context["runtime"]
self.workers_client = (
await self.runtime.namespace("dynamo-demo")
.component("DummyWorker")
.endpoint("worker_generate")
.client()
)
await check_required_workers(self.workers_client, self.min_workers, "kv router")
print("KV Router initialized")
def _cost_function(self, request_prompt):
worker_ids = self.workers_client.instance_ids()
num_workers = len(worker_ids)
max_hit_rate = -1.0
for curr_id in self.kv_cache.keys():
# Estimate hit rate by string matching
hit_rate = SequenceMatcher(
None, self.kv_cache[curr_id], request_prompt
).ratio()
if hit_rate > max_hit_rate:
max_hit_rate = hit_rate
max_id = curr_id
print(f"{max_hit_rate=},{len(self.kv_cache.keys())=}")
if max_hit_rate > self.threshold:
# Found the hit rate larger than the threshold
return max_id, max_hit_rate
elif len(self.kv_cache.keys()) == num_workers:
# Cache is already full, return the max rate
return max_id, max_hit_rate
else:
# Add current request into the cache
for curr_id in worker_ids:
if curr_id not in self.kv_cache.keys():
self.kv_cache[curr_id] = request_prompt
break
return curr_id, -1
# A dummy hit rate checking endpoint
# The actual worker selection is based on custom cost function
# See details at examples/llm/components/kv_router.py
@endpoint()
async def check_hit_rate(self, request_prompt: str) -> AsyncIterator[WorkerId]:
max_id, max_hit_rate = self._cost_function(request_prompt)
yield f"{max_id}_{max_hit_rate}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import socket
import sys
from components.utils import NixlMetadataStore, PrefillQueue, RemotePrefillRequest
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-demo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class PrefillWorker:
def __init__(self):
self._loaded_metadata = set()
self.initialized = False
self.hostname = socket.gethostname()
self.engine_id = self.hostname
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
# create dummy meta data
metadata = NixlMetadata(
engine_id=self.engine_id,
agent_metadata=[],
kv_caches_base_addr=[[]],
num_blocks=0,
)
self._metadata_store = NixlMetadataStore("dynamo-nixl", runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
task = asyncio.create_task(self.prefill_queue_handler())
def prefill_queue_handler_cb(fut):
try:
fut.result()
print("prefill queue handler exited successfully")
except Exception as e:
print(f"[ERROR] prefill queue handler failed: {e!r}")
sys.exit(1)
task.add_done_callback(prefill_queue_handler_cb)
print("PrefillWorker initialized")
async def prefill_queue_handler(self):
print("Prefill queue handler entered")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = "DummyLLM"
print(f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}")
self.initialized = True
# TODO: integrate prefill_queue to a dynamo endpoint
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
print("prefill queue handler started")
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
print(f"Dequeued prefill request: {prefill_request.request_id}")
async for _ in self.prefill_generate(prefill_request):
pass
async def prefill_generate(self, request: RemotePrefillRequest):
# TODO check if metadata has changed
# and reload - currently only loading once
print(f"prefill invoked {request.engine_id}{self._loaded_metadata=}")
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
# await self.engine_client.add_remote_nixl_metadata(remote_metadata)
print(f"Received nixl metadata from host {remote_metadata.engine_id}")
self._loaded_metadata.add(remote_metadata.engine_id)
print("Prefill invoked and will read KV cache from worker and write it back")
yield "prefill invoked"
@endpoint()
async def mock(self, req: RemotePrefillRequest):
yield f"mock_response: {req}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Protocol
from components.kv_router import Router
from components.utils import GeneralRequest, GeneralResponse, check_required_workers
from components.worker import DummyWorker
from dynamo._core import Client
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.dependency import DynamoClient
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-demo",
},
workers=1,
)
class Processor(Protocol):
"""
vLLM pre and post processing
"""
router: DynamoClient = depends(Router)
router_mode: str
min_workers: int
worker_client: Client
def __init__(self):
self.router_mode = "kv"
self.min_workers = 2
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = DummyWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("worker_generate")
.client()
)
await check_required_workers(
self.worker_client, self.min_workers, tag="processor"
)
async def _generate(
self,
raw_request: GeneralRequest,
):
if self.router_mode == "kv":
async for route_response in self.router.check_hit_rate(raw_request.prompt):
worker_id, prefix_hit_rate = route_response.split("_")
prefix_hit_rate = float(prefix_hit_rate)
print(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
break
if worker_id == "":
engine_generator = await self.worker_client.random(
raw_request.model_dump_json()
)
else:
engine_generator = await self.worker_client.direct(
raw_request.model_dump_json(),
int(worker_id),
)
elif self.router_mode == "random":
engine_generator = await self.worker_client.random(
raw_request.model_dump_json()
)
elif self.router_mode == "round-robin":
engine_generator = await self.worker_client.round_robin(
raw_request.model_dump_json()
)
async for resp in engine_generator:
yield GeneralResponse.model_validate_json(resp.data())
@endpoint()
async def processor_generate(self, raw_request: GeneralRequest):
async for response in self._generate(raw_request):
yield response.model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import ClassVar, Optional
import msgspec
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
from pydantic import BaseModel
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo._core import Client
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__)
class GeneralRequest(BaseModel):
prompt: str = "user input"
request_id: str = "id_string"
class GeneralResponse(BaseModel):
worker_output: str = "generated output"
request_id: str = "id_string"
class RemotePrefillRequest(msgspec.Struct, omit_defaults=True, dict=True):
engine_id: str = "Engine ID"
request_id: str = "id_string"
class NATSQueue:
_instance: ClassVar[Optional["NATSQueue"]] = None
_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
def __init__(
self,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
@classmethod
@asynccontextmanager
async def get_instance(
cls,
*,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
"""Get or create a singleton instance of NATSq"""
# TODO: check if this _lock is needed with GIL
async with cls._lock:
if cls._instance is None:
cls._instance = cls(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
await cls._instance.connect()
try:
yield cls._instance
except Exception:
if cls._instance:
await cls._instance.close()
cls._instance = None
raise
# TODO: check to see if this can be replaced by something like get_instance().close()
@classmethod
async def shutdown(cls):
"""Explicitly close the singleton instance if it exists"""
async with cls._lock:
if cls._instance:
await cls._instance.close()
cls._instance = None
async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name, subjects=[self._subject]
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
async def get_queue_size(self) -> int:
"""Get the number of messages currently in the queue"""
await self.ensure_connection()
try:
# Get consumer info to get pending messages count
consumer_info = await self._js.consumer_info( # type: ignore
self._stream_name, "worker-group"
)
# Return number of pending messages (real-time queue size)
return consumer_info.num_pending
except NatsError as e:
raise RuntimeError(f"Failed to get queue size: {e}")
class PrefillQueue(NATSQueue):
"""
A wrapper of NATSQueue for PrefillRequest.
The stream name is forced to be "prefill_queue".
"""
def __init__(
self,
stream_name="prefill_queue",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
super().__init__(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)
async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None
class NixlMetadataStore:
NIXL_METADATA_KEY = "nixl_metadata"
def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
self._namespace = namespace
# TODO Remove metadata from etcd on delete
self._stored: set[str] = set()
self._cached: dict[str, NixlMetadata] = {}
self._client = runtime.etcd_client()
if self._client is None:
raise Exception("Cannot be used with static workers")
self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"
async def put(self, engine_id, metadata: NixlMetadata):
serialized_metadata = msgspec.msgpack.encode(metadata)
key = "/".join([self._key_prefix, engine_id])
await self._client.kv_put(key, serialized_metadata, None)
self._stored.add(engine_id)
async def get(self, engine_id) -> NixlMetadata:
try:
if engine_id in self._cached:
return self._cached[engine_id]
key = "/".join([self._key_prefix, engine_id])
key_values = await self._client.kv_get_prefix(key)
deserialized_metadata = None
for item in key_values:
deserialized_metadata = msgspec.msgpack.decode(
item["value"], type=NixlMetadata
)
break
if deserialized_metadata is None:
raise Exception("metadata not found in etcd")
self._cached[engine_id] = deserialized_metadata
# TODO watch for changes and update cache
# self._client.add_watch_callback(
# key,
# self._watch_callback,
# )
except Exception as e:
raise Exception(f"Error retrieving metadata for engine {engine_id}") from e
return deserialized_metadata
async def check_required_workers(
workers_client: Client,
required_workers: int,
on_change=True,
poll_interval=5,
tag="",
):
"""Wait until the minimum number of workers are ready."""
worker_ids = workers_client.instance_ids()
num_workers = len(worker_ids)
new_count = -1 # Force to print "waiting for worker" once
while num_workers < required_workers:
if (not on_change) or new_count != num_workers:
num_workers = new_count if new_count >= 0 else num_workers
print(
f" {tag} Waiting for more workers to be ready.\n"
f" Current: {num_workers},"
f" Required: {required_workers}"
)
await asyncio.sleep(poll_interval)
worker_ids = workers_client.instance_ids()
new_count = len(worker_ids)
print(f"Workers ready: {worker_ids}")
return worker_ids
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import socket
from components.utils import (
GeneralRequest,
GeneralResponse,
NixlMetadataStore,
PrefillQueue,
RemotePrefillRequest,
)
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-demo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class DummyWorker:
def __init__(self):
self.hostname = socket.gethostname()
self.do_remote_prefill = True
self.model_name = "DummyLLM"
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = self.model_name
logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
if self.do_remote_prefill:
# Create dummy Nixl meta data
metadata = NixlMetadata(
engine_id=self.hostname,
agent_metadata=[],
kv_caches_base_addr=[[]],
num_blocks=0,
)
metadata_store = NixlMetadataStore("dynamo-nixl", runtime)
await metadata_store.put(metadata.engine_id, metadata)
self.disaggregated_router = "DummyDisaggregateRouter"
logger.info("VllmWorker has been initialized")
def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to dynamo endpoint
async def callback(request: RemotePrefillRequest):
print(
f"enqueue request {self._prefill_queue_nats_server}, \
{self._prefill_queue_stream_name},{request.engine_id=}"
)
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
return callback
@endpoint()
async def worker_generate(self, request: GeneralRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None:
# decision = (
# absolute_prefill_length > self.max_local_prefill_length
# and queue_size < self.max_prefill_queue_size )
# Disagg router decision is based on prefill length and queue size
# Always set to True in this demo (see details at disagg_router.py)
disagg_router_decision = True
else:
# always prefill remotely if no disaggregated router is provided
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
## Mimic the process of enqueue request for prefill
prefill_request = RemotePrefillRequest(
engine_id=self.hostname, request_id=request.request_id
)
callback = self.get_remote_prefill_request_callback()
await callback(prefill_request)
print(f"{self.hostname}: Worker invoked")
yield GeneralResponse(
request_id=request.request_id,
worker_output=request.prompt + "_GeneratedBy_" + self.hostname,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import (
DYNAMO_IMAGE,
api,
depends,
endpoint,
liveness,
on_shutdown,
readiness,
service,
)
from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__)
"""
Pipeline Architecture:
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/generate)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Middle │
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Backend │
└─────────────┘
"""
class RequestType(BaseModel):
text: str
class ResponseType(BaseModel):
text: str
@service(
dynamo={
"namespace": "inference",
},
resources={"cpu": 1, "memory": "500Mi"},
workers=2,
image=DYNAMO_IMAGE,
)
class Backend:
def __init__(self) -> None:
logger.info("Starting backend")
config = ServiceConfig.get_instance()
self.message = config.get("Backend", {}).get("message", "back")
logger.info(f"Backend config message: {self.message}")
@endpoint()
async def generate(self, req: RequestType):
"""Generate tokens."""
req_text = req.text
logger.info(f"Backend received: {req_text}")
text = f"{req_text}-{self.message}"
for token in text.split():
yield f"Backend: {token}"
@on_shutdown
def shutdown(self):
logger.info("Shutting down backend")
@service(
dynamo={"namespace": "inference"},
image=DYNAMO_IMAGE,
)
class Middle:
backend = depends(Backend)
def __init__(self) -> None:
logger.info("Starting middle")
config = ServiceConfig.get_instance()
self.message = config.get("Middle", {}).get("message", "mid")
logger.info(f"Middle config message: {self.message}")
@endpoint()
async def generate(self, req: RequestType):
"""Forward requests to backend."""
req_text = req.text
logger.info(f"Middle received: {req_text}")
text = f"{req_text}-{self.message}"
next_request = RequestType(text=text).model_dump_json()
async for response in self.backend.generate(next_request):
logger.info(f"Middle received response: {response}")
yield f"Middle: {response}"
@on_shutdown
def shutdown(self):
logger.info("Shutting down middle")
@service(
dynamo={"namespace": "inference"},
image=DYNAMO_IMAGE,
# Example of kubernetes overrides if needed.
# kubernetes_overrides={
# "entrypoint": ["sh -c"],
# "cmd": ["echo hello from FrontEnd!"],
# },
)
class Frontend:
"""A simple frontend HTTP API that forwards requests to the dynamo graph."""
middle = depends(Middle)
def __init__(self) -> None:
# Configure logging
configure_dynamo_logging(service_name="Frontend")
logger.info("Starting frontend")
config = ServiceConfig.get_instance()
self.message = config.get("Frontend", {}).get("message", "front")
self.port = config.get("Frontend", {}).get("port", 8000)
logger.info(f"Frontend config message: {self.message}")
logger.info(f"Frontend config port: {self.port}")
# alternative syntax: @endpoint(transports=[DynamoTransport.HTTP])
@api()
async def generate(self, request: RequestType):
"""Stream results from the pipeline."""
logger.info(f"Frontend received: {request.text}")
async def content_generator():
async for response in self.middle.generate(request.model_dump_json()):
yield f"Frontend: {response}"
return StreamingResponse(content_generator())
@liveness
def is_alive(self):
return True
@readiness
def is_ready(self):
return True
@on_shutdown
def shutdown(self):
logger.info("Shutting down frontend")
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Hello World MultiNode Example
## Overview
This example demonstrates how to deploy workers into multinodes and route requests to different workers.
Pipeline Architecture:
```
Users/Clients (HTTP)
┌─────────────────────┐
│ Frontend (node 1) │ HTTP API endpoint (/generate)
└─────────────────────┘
│ dynamo/runtime
┌─────────────────────┐
│ Processor (node 1) │ ─────────────────
└─────────────────────┘ routing │
│ dynamo/runtime │ dynamo/runtime
▼ ▼
┌─────────────────────┐ ┌─────────────────────┐
│ Worker_1 (node 2) │ │ Worker_2 (node 3) │
└─────────────────────┘ └─────────────────────┘
```
## Component Descriptions
### Frontend Service
- Serves as the entry point for external HTTP requests
- Exposes a `/generate` HTTP API endpoint that clients can call
- Processes incoming text and passes it to the Middle service
### Processor Service
- Acts as an intermediary service in the pipeline
- Deployed on the same node as Frontend and receives requests from the Frontend
- Calls multiple workers based on the routing mode, random or round-robin.
### Worker Service
- Functions as the final service in the pipeline
- Deployed on a different node from Frontend and Processor
- Appends "GeneratedBy_HostName" to the text and yields tokens
## Prerequisites
Start required services (etcd and NATS) using [Docker Compose](../../../deploy/docker-compose.yml)
```bash
docker compose -f deploy/docker-compose.yml up -d
```
## Running the Single Worker Example
In this example, we will use two nodes to demo the multinode serving.
- Node 1
- Runs NATS and etcd services
- Deploys Frontend and Processor
- Node 2
- Deploys Worker
1. Set environment variables for NATS and etcd services
```bash
export NATS_SERVER="nats://Node_1_IP_ADDRESS:4222"
export ETCD_ENDPOINTS="http://Node_1_IP_ADDRESS:2379"
```
2. Launch Frontend and Processor services:
```bash
cd dynamo/examples/hello_world/multinode_example
dynamo serve components.graph:Frontend -f configs/one_worker.yaml
```
The `dynamo serve` command deploys the entire service graph, automatically handling the dependencies between Frontend, and Processor services. Since no worker is deployed yet, the service remains idle.
![text](./_img/waiting1worker.png)
3. Go to node 2 and launch Worker service
```bash
export NATS_SERVER="nats://Node_1_IP_ADDRESS:4222"
export ETCD_ENDPOINTS="http://Node_1_IP_ADDRESS:2379"
cd dynamo/examples/hello_world/multinode_example
dynamo serve components.worker:DummyWorker
```
You should see the worker is ready from node 1's terminal.
![text](./_img/1workerready.png)
4. Go back to node 1 and send request to frontend using curl:
```bash
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "test prompt",
"request_id": "id_number"
}'
```
5. You should be able to see response as below:
`Response: {"worker_output":"test prompt_ProcessedBy_NODE1HOSTNAME_GeneratedBy_NODE2HOSTNAME","request_id":"id_number"}`
Here `NODE1HOSTNAME` is the hostname for node 1, and `NODE2HOSTNAME` is the hostname for node 2.
## Running the Two Workers Example
In this example, we will use three nodes to demo the multinode serving.
- Node 1
- Runs NATS and etcd services
- Deploys Frontend and Processor
- Node 2
- Deploys Worker 1
- Node 3
- Deploys Worker 2
1. Launch Frontend and Processor services using the `multi_worker.yaml` config from node 1. In this config file, we require 2 workers and set the router mode as **round robin**
```bash
dynamo serve components.graph:Frontend -f configs/multi_worker.yaml
```
The service is waiting for 2 workers this time.
2. Go to node 2 and node 3, launch worker service separately
```bash
export NATS_SERVER="nats://Node_1_IP_ADDRESS:4222"
export ETCD_ENDPOINTS="http://Node_1_IP_ADDRESS:2379"
dynamo serve components.worker:DummyWorker
```
You should see the following messages from node 1's terminal window when both workers are deployed
![text](./_img/2workerready.png)
3. Query the frontend using the same query as before, and run it multiple times. You should see following two responses in turn because of round-robin routing mode between 2 workers.
Response from worker 1: `Response: {"worker_output":"test prompt_ProcessedBy_NODE1HOSTNAME_GeneratedBy_NODE2HOSTNAME","request_id":"id_number"}`
Response from worker 2: `Response: {"worker_output":"test prompt_ProcessedBy_NODE1HOSTNAME_GeneratedBy_NODE3HOSTNAME","request_id":"id_number"}`
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from components.processor import Processor
from components.utils import GeneralRequest
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_endpoint, service
logger = logging.getLogger(__name__)
app = FastAPI(title="Hello World!")
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-demo",
},
image=DYNAMO_IMAGE,
app=app,
)
class Frontend:
processor = depends(Processor)
@dynamo_endpoint(is_api=True)
async def generate(self, request: GeneralRequest): # from request body keys
"""Stream results from the pipeline."""
logger.info(f"-Frontend layer received: {request=}")
async def content_generator():
async for response in self.processor.generate(request.model_dump_json()):
yield f"Frontend: {response}"
return StreamingResponse(content_generator())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.frontend import Frontend
from components.processor import Processor
Frontend.link(Processor)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from typing import Protocol
from components.utils import GeneralRequest, GeneralResponse, check_required_workers
from components.worker import DummyWorker
from dynamo._core import Client
from dynamo.sdk import (
DYNAMO_IMAGE,
async_on_start,
depends,
dynamo_context,
dynamo_endpoint,
service,
)
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.dependency import DynamoClient
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-demo",
},
image=DYNAMO_IMAGE,
)
class Processor(Protocol):
"""
Pre and Post Processing
"""
worker: DynamoClient = depends(DummyWorker)
router: str
hostname: str
min_workers: int
worker_client: Client
def __init__(self):
config = ServiceConfig.get_instance()
processor_config = config.get("Processor", {})
self.hostname = socket.gethostname()
self.min_workers = processor_config.get("min_worker", 1)
self.router = processor_config.get("router", "round-robin")
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = DummyWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(
self.worker_client, self.min_workers, tag="processor"
)
logger.info(f"----workers are all ready {self.worker_client.endpoint_ids()}")
async def _generate(
self,
raw_request: GeneralRequest,
):
raw_request.prompt = raw_request.prompt + "_ProcessedBy_" + self.hostname
if self.router == "random":
engine_generator = await self.worker_client.random(
raw_request.model_dump_json()
)
elif self.router == "round-robin":
engine_generator = await self.worker_client.round_robin(
raw_request.model_dump_json()
)
async for resp in engine_generator:
yield GeneralResponse.model_validate_json(resp.data())
@dynamo_endpoint()
async def generate(self, request: GeneralRequest):
"""Forward requests to backend."""
mid_request = request.model_dump_json()
logger.info(f"Received request{mid_request=}")
async for response in self._generate(request):
logger.debug(f"Received response: {response.model_dump_json()}")
yield response.model_dump_json()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment