Unverified Commit 10e91264 authored by Kris Hung, committed by GitHub

feat: Add multimodal example with disaggregated serving (#811)

parent 92bbbc39
@@ -17,16 +17,18 @@ limitations under the License.
# Multimodal Deployment Examples

This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo.
The examples are based on the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model.
## Multimodal Aggregated Serving
### Components
- workers: For aggregated serving, we have two workers, [encode_worker](components/encode_worker.py) for encoding and [vllm_worker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the vllm worker.
- frontend: HTTP endpoint to handle incoming requests.
### Deployment
In this deployment, we have two workers, [encode_worker](components/encode_worker.py) and [vllm_worker](components/worker.py).
The encode worker is responsible for encoding the image and passing the embeddings to the vllm worker via NATS.
@@ -69,3 +71,56 @@ You should see a response similar to this:
```
" The image features a close-up view of the front of a bus, with a prominent neon sign clearly displayed. The bus appears to be slightly past its prime condition, beyond its out-of-service section. Inside the bus, we see a depth of text, with the sign saying \"out of service\". A wide array of windows line the side of the double-decker bus, making its overall appearance quite interesting and vintage."
```
## Multimodal Disaggregated Serving
### Components
- workers: For disaggregated serving, we have three workers, [encode_worker](components/encode_worker.py) for encoding, [vllm_worker](components/worker.py) for decoding, and [prefill_worker](components/prefill_worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the vllm worker.
- frontend: HTTP endpoint to handle incoming requests.
### Deployment
In this deployment, we have three workers, [encode_worker](components/encode_worker.py), [vllm_worker](components/worker.py), and [prefill_worker](components/prefill_worker.py).
For the Llava model, embeddings are only required during the prefill stage. As such, the encode worker is connected directly to the prefill worker.
The encode worker handles image encoding and transmits the resulting embeddings to the prefill worker via NATS.
The prefill worker performs the prefilling step and forwards the KV cache to the vllm worker for decoding.
For more details on the roles of the prefill and vllm workers, refer to the [LLM disaggregated serving](../llm/README.md) example.
This figure shows the flow of the deployment:
```
+------+ +-----------+ +------------------+ +------------------+ image url +---------------+
| HTTP |----->| processor |----->| vllm worker |----->| prefill worker |--------------------->| encode worker |
| |<-----| |<-----| (decode worker) |<-----| |<---------------------| |
+------+ +-----------+ +------------------+ +------------------+ image embeddings +---------------+
```
```bash
cd $DYNAMO_HOME/examples/multimodal
dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml
```
### Client
In another terminal:
```bash
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"model":"llava-hf/llava-1.5-7b-hf",
"image":"http://images.cocodataset.org/val2017/000000324158.jpg",
"prompt":"Describe the mood and setting of this image in two sentences. What time of day do you think it is?",
"max_tokens":300
}' | jq
```
You should see a response similar to this:
```
" The image depicts a man moving across a field on a skateboard. The setting appears to be joyful, and this activity suggests that the man is enjoying an outdoor adventure. Additionally, a pet dog is probably accompanying, contributing to the positive mood. The mood and setting of the image appear lively and shoal. The sun is most likely low in the sky, as this would produce a nice daylight."
```
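The same request can also be issued from a short Python script instead of curl. This is a minimal sketch that assumes only what the curl example above shows (the endpoint URL, headers, and JSON fields); it prints the raw event-stream lines rather than pretty-printing through `jq`:

```python
import requests  # third-party HTTP client, assumed installed

payload = {
    "model": "llava-hf/llava-1.5-7b-hf",
    "image": "http://images.cocodataset.org/val2017/000000324158.jpg",
    "prompt": "Describe the mood and setting of this image in two sentences. What time of day do you think it is?",
    "max_tokens": 300,
}

# Stream the response; the endpoint emits server-sent events,
# so we print each event line as it arrives.
with requests.post(
    "http://localhost:8000/generate",
    json=payload,
    headers={"accept": "text/event-stream"},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```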
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from dynamo.runtime import EtcdKvCache
from dynamo.sdk import dynamo_context

logger = logging.getLogger(__name__)


class PyDisaggregatedRouter:
    def __init__(
        self,
        runtime,
        served_model_name,
        max_local_prefill_length=1000,
        max_prefill_queue_size=2,
    ):
        self.runtime = runtime
        self.served_model_name = served_model_name
        self.max_local_prefill_length = max_local_prefill_length
        self.max_prefill_queue_size = max_prefill_queue_size

    async def async_init(self):
        runtime = dynamo_context["runtime"]
        self.etcd_kv_cache = await EtcdKvCache.create(
            runtime.etcd_client(),
            "/dynamo/disagg_router/",
            {
                "max_local_prefill_length": str(self.max_local_prefill_length),
                "max_prefill_queue_size": str(self.max_prefill_queue_size),
            },
        )

    async def prefill_remote(
        self, prompt_length: int, prefix_hit_rate: float, queue_size: int
    ):
        max_local_prefill_length = int(
            await self.etcd_kv_cache.get("max_local_prefill_length")
        )
        max_prefill_queue_size = int(
            await self.etcd_kv_cache.get("max_prefill_queue_size")
        )
        absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
        # TODO: consider size of each request in the queue when making the decision
        decision = (
            absolute_prefill_length > max_local_prefill_length
            and queue_size < max_prefill_queue_size
        )
        logger.info(
            f"Remote prefill: {decision} (prefill length: {absolute_prefill_length}/{max_local_prefill_length}, prefill queue size: {queue_size}/{max_prefill_queue_size})"
        )
        return decision
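To make the routing rule concrete, here is a standalone rendering of the same decision logic with the default thresholds. The `should_prefill_remotely` helper is illustrative only; the real router reads its thresholds from etcd at decision time, as shown above:

```python
def should_prefill_remotely(
    prompt_length: int,
    prefix_hit_rate: float,
    queue_size: int,
    max_local_prefill_length: int = 1000,
    max_prefill_queue_size: int = 2,
) -> bool:
    # Tokens that still need prefilling after prefix-cache hits.
    absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
    # Offload only when the remaining prefill is large AND the remote queue has room.
    return (
        absolute_prefill_length > max_local_prefill_length
        and queue_size < max_prefill_queue_size
    )

# 2000-token prompt, half already cached -> 1000 tokens to prefill: stays local (not > 1000).
assert should_prefill_remotely(2000, 0.5, 0) is False
# 4000-token prompt, no cache hits, queue has room -> goes remote.
assert should_prefill_remotely(4000, 0.0, 1) is True
# Same prompt, but the prefill queue is full -> falls back to local prefill.
assert should_prefill_remotely(4000, 0.0, 2) is False
```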
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
import sys

import torch
from components.encode_worker import EncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest

from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service

logger = logging.getLogger(__name__)


class RequestType(BaseModel):
    text: str


@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class PrefillWorker:
    encode_worker = depends(EncodeWorker)

    def __init__(self):
        class_name = self.__class__.__name__
        self.engine_args = parse_vllm_args(class_name, "")
        self._loaded_metadata = set()
        self.initialized = False
        self.min_workers = 1

        if self.engine_args.enable_chunked_prefill is not False:
            logger.info("Chunked prefill is not supported yet, setting to False")
            self.engine_args.enable_chunked_prefill = False

        if self.engine_args.pipeline_parallel_size != 1:
            logger.info("Pipeline parallel size is not supported yet, setting to 1")
            self.engine_args.pipeline_parallel_size = 1

        if self.engine_args.disable_async_output_proc is not True:
            logger.info("Async output processing is not supported yet, setting to True")
            self.engine_args.disable_async_output_proc = True

        if self.engine_args.enforce_eager is not True:
            logger.info("Prefill must be done eagerly, setting to True")
            self.engine_args.enforce_eager = True

        if self.engine_args.enable_prefix_caching is not False:
            logger.info(
                "Prefix caching is not supported yet in prefill worker, setting to False"
            )
            self.engine_args.enable_prefix_caching = False

        signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
        signal.signal(signal.SIGINT, self.shutdown_vllm_engine)

    @async_on_start
    async def async_init(self):
        self._engine_context = build_async_engine_client_from_engine_args(
            self.engine_args
        )
        if self._engine_context is not None:
            self.engine_client = await self._engine_context.__aenter__()
        else:
            raise RuntimeError("Failed to initialize engine client")

        runtime = dynamo_context["runtime"]

        enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address()  # type: ignore
        self.encode_worker_client = (
            await runtime.namespace(enc_comp_ns)
            .component(enc_comp_name)
            .endpoint("encode")
            .client()
        )

        await check_required_workers(self.encode_worker_client, self.min_workers)

        metadata = self.engine_client.nixl_metadata
        self._metadata_store = NixlMetadataStore("dynamo", runtime)
        await self._metadata_store.put(metadata.engine_id, metadata)

        task = asyncio.create_task(self.prefill_queue_handler())

        def prefill_queue_handler_cb(fut):
            try:
                fut.result()
                logger.info("prefill queue handler exited successfully")
            except Exception as e:
                logger.error(f"[ERROR] prefill queue handler failed: {e!r}")
                sys.exit(1)

        task.add_done_callback(prefill_queue_handler_cb)
        logger.info("PrefillWorker initialized")

    def shutdown_vllm_engine(self, signum, frame):
        """Shutdown the background loop"""
        logger.info(f"Received signal {signum}, shutting down")
        loop = asyncio.get_event_loop()
        try:
            self.engine_client.close()
            logger.info("PrefillWorker shutdown complete")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")
        finally:
            loop.stop()

    async def prefill_queue_handler(self):
        logger.info("Prefill queue handler entered")
        prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
        prefill_queue_stream_name = (
            self.engine_args.served_model_name
            if self.engine_args.served_model_name is not None
            else "vllm"
        )
        logger.info(
            f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}"
        )
        self.initialized = True
        # TODO: integrate prefill_queue to a dynamo endpoint
        async with PrefillQueue.get_instance(
            nats_server=prefill_queue_nats_server,
            stream_name=prefill_queue_stream_name,
        ) as prefill_queue:
            logger.info("prefill queue handler started")
            while True:
                # TODO: this might add a small overhead to pull prefill from nats
                # need to test and check how much overhead it is
                prefill_request = await prefill_queue.dequeue_prefill_request()
                if prefill_request is not None:
                    logger.info(
                        f"Dequeued prefill request: {prefill_request.request_id}"
                    )
                    async for _ in self.generate(prefill_request):
                        pass

    async def generate(self, request: RemotePrefillRequest):
        if request.multimodal_data_source["image_url"] is None:
            raise ValueError("No image url provided for prefill request")

        encode_generator = await self.encode_worker_client.round_robin(
            EncodeRequest(
                image_url=request.multimodal_data_source["image_url"],
            ).model_dump_json()
        )
        async for encode_response in encode_generator:
            encode_output = EncodeResponse.model_validate_json(encode_response.data())
            image_features = torch.tensor(
                encode_output.image_features, device="cpu", dtype=torch.float16
            )

        sampling_params = request.sampling_params
        sampling_params.max_tokens = 1
        sampling_params.min_tokens = 1

        remote_prefill_params = RemotePrefillParams(
            is_remote_decode=True,
            decode_block_ids=request.block_ids,
            decode_engine_id=request.engine_id,
            decode_computed_block_ids=request.computed_block_ids,
        )

        # TODO check if metadata has changed
        # and reload - currently only loading once
        if request.engine_id not in self._loaded_metadata:
            remote_metadata = await self._metadata_store.get(request.engine_id)
            await self.engine_client.add_remote_nixl_metadata(remote_metadata)
            logger.info(
                f"Loaded nixl metadata from engine {request.engine_id} into "
                f"engine {self.engine_client.nixl_metadata.engine_id}"
            )
            self._loaded_metadata.add(request.engine_id)

        # To make sure the decode worker can pre-allocate the memory with the correct size
        # for the prefill worker to transfer the kv cache, some placeholder dummy tokens
        # were inserted in worker.py based on the embedding size.
        # The prompt is structured as "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:",
        # so the dummy tokens after the image token need to be removed here.
        IMAGE_TOKEN_ID = 32000
        embedding_size = image_features.shape[1]
        padding_size = embedding_size - 1
        image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
        dummy_token_index = image_token_index + 1
        prompt_token_ids = (
            request.prompt_token_ids[:dummy_token_index]
            + request.prompt_token_ids[dummy_token_index + padding_size :]
        )

        async for _ in self.engine_client.generate(
            request_id=request.request_id,
            prompt=TokensPrompt(
                prompt_token_ids=prompt_token_ids,
                multi_modal_data={"image": image_features},
            ),
            sampling_params=sampling_params,
            remote_prefill_params=remote_prefill_params,
        ):
            yield

    @dynamo_endpoint()
    async def mock(self, req: RequestType):
        yield f"mock_response: {req}"
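The dummy-token stripping above is one half of a contract with the decode worker in `worker.py` (shown later in this diff), which inserts the padding. A self-contained round trip of the arithmetic, using illustrative token ids and an illustrative embedding size:

```python
IMAGE_TOKEN_ID = 32000
DUMMY_TOKEN_ID = 0
embedding_size = 576  # illustrative; the workers derive this from the vision tower / feature shape

# Made-up token ids standing in for "\nUSER: <image>\n<user_prompt>\nASSISTANT:".
prompt_token_ids = [1, 3148, IMAGE_TOKEN_ID, 13, 20355, 29901]

# Decode worker side (worker.py): insert (embedding_size - 1) dummy tokens after
# <image> so KV-cache memory is pre-allocated for the full multimodal prompt.
image_token_index = prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
padded = (
    prompt_token_ids[:dummy_token_index]
    + [DUMMY_TOKEN_ID] * (embedding_size - 1)
    + prompt_token_ids[dummy_token_index:]
)
assert len(padded) == len(prompt_token_ids) + embedding_size - 1

# Prefill worker side (generate() above): strip the same padding before prefilling.
padding_size = embedding_size - 1
stripped = padded[:dummy_token_index] + padded[dummy_token_index + padding_size :]
assert stripped == prompt_token_ids
```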
@@ -15,11 +15,18 @@
import asyncio
import logging
import os
import signal
from typing import Optional

import torch
from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import EncodeWorker
from components.prefill_worker import PrefillWorker
from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import (
    EncodeRequest,
    EncodeResponse,
@@ -31,6 +38,7 @@ from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind

from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
@@ -47,11 +55,15 @@ logger = logging.getLogger(__name__)
    workers=1,
)
class VllmWorker:
    # For disaggregated serving, we need to link the prefill worker to the vllm worker
    prefill_worker = depends(PrefillWorker)
    # For aggregated serving, we need to link the encode worker to the vllm worker.
    encode_worker = depends(EncodeWorker)

    def __init__(self):
        self.client = None
        self.min_workers = 1
        self.disaggregated_router: Optional[PyDisaggregatedRouter] = None
        class_name = self.__class__.__name__
        self.engine_args = parse_vllm_args(class_name, "")
        self.do_remote_prefill = self.engine_args.remote_prefill
@@ -60,10 +72,30 @@ class VllmWorker:
            if self.engine_args.served_model_name is not None
            else "vllm"
        )
        self._prefill_queue_nats_server = os.getenv(
            "NATS_SERVER", "nats://localhost:4222"
        )
        self._prefill_queue_stream_name = self.model_name
        logger.info(
            f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
        )

        if self.engine_args.remote_prefill:
            if self.engine_args.enable_chunked_prefill is not False:
                logger.info("Chunked prefill is not supported yet, setting to False")
                self.engine_args.enable_chunked_prefill = False

            if self.engine_args.preemption_mode != "swap":
                logger.info("Preemption mode is not supported yet, setting to swap")
                self.engine_args.preemption_mode = "swap"

            if self.engine_args.pipeline_parallel_size != 1:
                logger.info("Pipeline parallel size is not supported yet, setting to 1")
                self.engine_args.pipeline_parallel_size = 1

        if self.engine_args.router == "kv":
            raise NotImplementedError(
                "Multimodal requests are not supported for kv router mode"
            )

        signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
@@ -86,17 +118,40 @@
        runtime = dynamo_context["runtime"]

        if self.do_remote_prefill:
            metadata = self.engine_client.nixl_metadata
            metadata_store = NixlMetadataStore("dynamo", runtime)
            await metadata_store.put(metadata.engine_id, metadata)

            if self.engine_args.conditional_disagg:
                self.disaggregated_router = PyDisaggregatedRouter(
                    runtime,
                    self.model_name,
                    max_local_prefill_length=self.engine_args.max_local_prefill_length,
                    max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
                )
                await self.disaggregated_router.async_init()
            else:
                self.disaggregated_router = None

            model = LlavaForConditionalGeneration.from_pretrained(
                self.engine_args.model
            )
            vision_tower = model.vision_tower
            self.embedding_size = (
                vision_tower.vision_model.embeddings.position_embedding.num_embeddings
            )
        else:
            enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address()  # type: ignore
            self.encode_worker_client = (
                await runtime.namespace(enc_comp_ns)
                .component(enc_comp_name)
                .endpoint("encode")
                .client()
            )
            await check_required_workers(self.encode_worker_client, self.min_workers)
            self.disaggregated_router = None

        logger.info("VllmWorker has been initialized")

    def shutdown_vllm_engine(self, signum, frame):
@@ -111,34 +166,105 @@ class VllmWorker:
        finally:
            loop.stop()
    def get_remote_prefill_request_callback(self):
        async def callback(request: RemotePrefillRequest):
            async with PrefillQueue.get_instance(
                nats_server=self._prefill_queue_nats_server,
                stream_name=self._prefill_queue_stream_name,
            ) as prefill_queue:
                await prefill_queue.enqueue_prefill_request(request)

        return callback

    @dynamo_endpoint()
    async def generate(self, request: vLLMMultimodalRequest):
        image_features = None
        if self.do_remote_prefill:
            if self.disaggregated_router is not None:
                async with PrefillQueue.get_instance(
                    nats_server=self._prefill_queue_nats_server,
                    stream_name=self._prefill_queue_stream_name,
                ) as prefill_queue:
                    prefill_queue_size = await prefill_queue.get_queue_size()
                disagg_router_decision = await self.disaggregated_router.prefill_remote(
                    len(request.engine_prompt["prompt_token_ids"]),
                    request.prefix_hit_rate,
                    prefill_queue_size,
                )
            else:
                # always prefill remotely if no disaggregated router is provided
                disagg_router_decision = True

            if self.do_remote_prefill and disagg_router_decision:
                remote_prefill_params = RemotePrefillParams(
                    is_remote_prefill=True,
                    remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
                    # Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
                    multimodal_data_source={
                        "image_url": request.image_url,
                    },
                )
                logger.info(
                    f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
                )
            else:
                remote_prefill_params = None
                logger.info(
                    f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
                )

            # The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
            # As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
            # so that decode worker can pre-allocate the memory with the correct size.
            # The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
            # Since the "<image>" token is included in the prompt, only need to insert (embedding_size - 1) dummy tokens after the image token.
            IMAGE_TOKEN_ID = 32000
            DUMMY_TOKEN_ID = 0
            # Find the index of the image token in the prompt token ids
            image_token_index = request.engine_prompt["prompt_token_ids"].index(
                IMAGE_TOKEN_ID
            )
            dummy_token_index = image_token_index + 1
            prompt_ids = (
                request.engine_prompt["prompt_token_ids"][:dummy_token_index]
                + [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
                + request.engine_prompt["prompt_token_ids"][dummy_token_index:]
            )
        else:
            # For aggregated serving, the vllm worker will directly send the encode request to the encode worker.
            encode_generator = await self.encode_worker_client.round_robin(
                EncodeRequest(
                    image_url=request.image_url,
                ).model_dump_json()
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            async for encode_response in encode_generator:
                encode_output = EncodeResponse.model_validate_json(
                    encode_response.data()
                )
                image_features = torch.tensor(
                    encode_output.image_features, device=device, dtype=torch.float16
                )
            remote_prefill_params = None
            logger.info(
                f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
            )
            prompt_ids = request.engine_prompt["prompt_token_ids"]

        # rust HTTP requires Delta streaming
        request.sampling_params.output_kind = RequestOutputKind.DELTA

        if image_features is not None:
            multi_modal_data = {"image": image_features}
        else:
            multi_modal_data = None

        async for response in self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=prompt_ids,
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=request.sampling_params,
            request_id=request.request_id,
...
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
  model: llava-hf/llava-1.5-7b-hf
  block-size: 64
  max-model-len: 4096
  kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'

Processor:
  router: round-robin
  common-configs: [model, block-size]

VllmWorker:
  remote-prefill: true
  conditional-disagg: true
  max-local-prefill-length: 10
  max-prefill-queue-size: 2
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1
  common-configs: [model, block-size, max-model-len, kv-transfer-config]

PrefillWorker:
  max-num-batched-tokens: 16384
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1
  common-configs: [model, block-size, max-model-len, kv-transfer-config]

EncodeWorker:
  tensor-parallel-size: 1
  router: random
  ServiceArgs:
    workers: 1
    resources:
      gpu: 1
  common-configs: [model]
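Note that `max-local-prefill-length` is set to only 10 tokens here, so with `conditional-disagg` enabled virtually every request exceeds the local-prefill threshold and is routed to the remote prefill worker whenever the prefill queue has room; raising this value keeps more short prompts on the decode worker.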
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -12,3 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.encode_worker import EncodeWorker
from components.frontend import Frontend
from components.prefill_worker import PrefillWorker
from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(PrefillWorker).link(EncodeWorker)
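This graph chains all five components for disaggregated serving. For contrast, an aggregated graph would drop the `PrefillWorker` link so the vllm worker reaches the encode worker directly; a hypothetical sketch (the module and file name are assumptions, not part of this diff):

```python
from components.encode_worker import EncodeWorker
from components.frontend import Frontend
from components.processor import Processor
from components.worker import VllmWorker

# Hypothetical aggregated graph: no PrefillWorker in the chain, so the
# vllm worker calls the encode worker directly.
Frontend.link(Processor).link(VllmWorker).link(EncodeWorker)
```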
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional

from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError


class NATSQueue:
    _instance: ClassVar[Optional["NATSQueue"]] = None
    _lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    def __init__(
        self,
        stream_name: str = "default",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        self.nats_url = nats_server
        self._nc: Optional[NATS] = None
        self._js: Optional[JetStreamContext] = None
        # TODO: check if this is needed
        # Sanitize stream_name to remove path separators
        self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
        self._subject = f"{self._stream_name}.*"
        self.dequeue_timeout = dequeue_timeout
        self._subscriber: Optional[JetStreamContext.PullSubscription] = None

    @classmethod
    @asynccontextmanager
    async def get_instance(
        cls,
        *,
        stream_name: str = "default",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        """Get or create a singleton instance of NATSQueue"""
        # TODO: check if this _lock is needed with GIL
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls(
                    stream_name=stream_name,
                    nats_server=nats_server,
                    dequeue_timeout=dequeue_timeout,
                )
                await cls._instance.connect()
            try:
                yield cls._instance
            except Exception:
                if cls._instance:
                    await cls._instance.close()
                    cls._instance = None
                raise

    # TODO: check to see if this can be replaced by something like get_instance().close()
    @classmethod
    async def shutdown(cls):
        """Explicitly close the singleton instance if it exists"""
        async with cls._lock:
            if cls._instance:
                await cls._instance.close()
                cls._instance = None

    async def connect(self):
        """Establish connection and create stream if needed"""
        try:
            if self._nc is None:
                self._nc = NATS()
                await self._nc.connect(self.nats_url)
                self._js = self._nc.jetstream()
                # Check if stream exists, if not create it
                try:
                    await self._js.stream_info(self._stream_name)
                except NotFoundError:
                    await self._js.add_stream(
                        name=self._stream_name, subjects=[self._subject]
                    )
                # Create persistent subscriber
                self._subscriber = await self._js.pull_subscribe(
                    f"{self._stream_name}.queue", durable="worker-group"
                )
        except NatsError as e:
            await self.close()
            raise ConnectionError(f"Failed to connect to NATS: {e}")

    async def ensure_connection(self):
        """Ensure we have an active connection"""
        if self._nc is None or self._nc.is_closed:
            await self.connect()

    async def close(self):
        """Close the connection when done"""
        if self._nc:
            await self._nc.close()
            self._nc = None
            self._js = None
            self._subscriber = None

    # TODO: is enqueue/dequeue_object a better name for a general queue?
    async def enqueue_task(self, task_data: bytes) -> None:
        """
        Enqueue a task using msgspec-encoded data
        """
        await self.ensure_connection()
        try:
            await self._js.publish(f"{self._stream_name}.queue", task_data)  # type: ignore
        except NatsError as e:
            raise RuntimeError(f"Failed to enqueue task: {e}")

    async def dequeue_task(self) -> Optional[bytes]:
        """Dequeue and return a task as raw bytes, to be decoded with msgspec"""
        await self.ensure_connection()
        try:
            msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout)  # type: ignore
            if msgs:
                msg = msgs[0]
                await msg.ack()
                return msg.data
            return None
        except asyncio.TimeoutError:
            return None
        except NatsError as e:
            raise RuntimeError(f"Failed to dequeue task: {e}")

    async def get_queue_size(self) -> int:
        """Get the number of messages currently in the queue"""
        await self.ensure_connection()
        try:
            # Get consumer info to get pending messages count
            consumer_info = await self._js.consumer_info(  # type: ignore
                self._stream_name, "worker-group"
            )
            # Return number of pending messages (real-time queue size)
            return consumer_info.num_pending
        except NatsError as e:
            raise RuntimeError(f"Failed to get queue size: {e}")
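A minimal usage sketch for the queue above, assuming a JetStream-enabled NATS server is reachable at the default `nats://localhost:4222`:

```python
import asyncio

from utils.nats_queue import NATSQueue


async def main():
    # Round-trip a raw payload through the queue. Note get_instance() is a
    # singleton: later calls reuse the first instance's stream settings.
    async with NATSQueue.get_instance(stream_name="demo") as queue:
        await queue.enqueue_task(b'{"job": 1}')
        print(await queue.get_queue_size())  # should print 1
        print(await queue.dequeue_task())    # should print b'{"job": 1}'


asyncio.run(main())
```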
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest


class PrefillQueue(NATSQueue):
    """
    A wrapper of NATSQueue for PrefillRequest.
    The stream name is forced to be "prefill_queue".
    """

    def __init__(
        self,
        stream_name="prefill_queue",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        super().__init__(
            stream_name=stream_name,
            nats_server=nats_server,
            dequeue_timeout=dequeue_timeout,
        )

    async def enqueue_prefill_request(
        self, prefill_request: RemotePrefillRequest
    ) -> None:
        encoded_request = msgspec.json.encode(prefill_request)
        await self.enqueue_task(encoded_request)

    async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
        encoded_request = await self.dequeue_task()
        if encoded_request is not None:
            prefill_request = msgspec.json.decode(
                encoded_request, type=RemotePrefillRequest
            )
            return prefill_request
        else:
            return None
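The wrapper only adds the msgspec (de)serialization step around `NATSQueue`. A tiny round trip of that step with a stand-in struct (`FakePrefillRequest` is hypothetical; the real field set lives on vllm's `RemotePrefillRequest`):

```python
import msgspec


class FakePrefillRequest(msgspec.Struct):
    # Stand-in for vllm's RemotePrefillRequest; fields are illustrative only.
    request_id: str
    prompt_token_ids: list[int]


req = FakePrefillRequest(request_id="req-1", prompt_token_ids=[1, 2, 3])

# enqueue_prefill_request encodes with msgspec.json before handing bytes to NATS...
encoded = msgspec.json.encode(req)
# ...and dequeue_prefill_request decodes them back into a typed object.
decoded = msgspec.json.decode(encoded, type=FakePrefillRequest)
assert decoded == req
```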