Unverified Commit 58df5aca authored by Kris Hung, committed by GitHub

feat: Add multimodal example with aggregated serving (#709)

parent f122aa4e
@@ -24,3 +24,4 @@ pytest-mypy
pytest-timeout
# add types library stub for PyYAML
types-PyYAML
types-requests
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
accelerate==1.6.0
fastapi==0.115.6
ftfy
grpcio-tools==1.66.0
...
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Multimodal Deployment Examples
This directory contains examples and reference implementations for deploying a multimodal model with Dynamo.
## Components
- workers: For aggregated serving, we have two workers, [encode_worker](components/encode_worker.py) for encoding and [vllm_worker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the vllm worker.
- frontend: HTTP endpoint to handle incoming requests.
## Multimodal Aggregated Serving
In this deployment, we have two workers, [encode_worker](components/encode_worker.py) and [vllm_worker](components/worker.py).
The encode worker is responsible for encoding the image and passing the embeddings to the vllm worker via NATS.
The vllm worker then prefills and decodes the prompt, just like the [LLM aggregated serving](../llm/README.md) example.
By separating encoding from the prefill and decode stages, the deployment is more flexible: the encode worker can be scaled independently of the prefill and decode workers if needed.
This figure shows the flow of the deployment:
```
+------+ +-----------+ +------------------+ image url +---------------+
| HTTP |----->| processor |----->| vllm worker |--------------------->| encode worker |
| |<-----| |<-----| |<---------------------| |
+------+ +-----------+ +------------------+ image embeddings +---------------+
```
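The components are wired together in [graphs/agg.py](graphs/agg.py) (the `graphs.agg` module used by the serve command below), reproduced here for reference:
```python
# graphs/agg.py — links Frontend -> Processor -> VllmWorker -> EncodeWorker
from components.encode_worker import EncodeWorker
from components.frontend import Frontend
from components.processor import Processor
from components.worker import VllmWorker

Frontend.link(Processor).link(VllmWorker).link(EncodeWorker)
```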
To deploy the aggregated serving graph:
```bash
cd $DYNAMO_HOME/examples/multimodal
dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
```
### Client
In another terminal:
```bash
curl -X 'POST' \
'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"model":"llava-hf/llava-1.5-7b-hf",
"image":"http://images.cocodataset.org/test2017/000000155781.jpg",
"prompt":"Describe the image",
"max_tokens":300
}' | jq
```
You should see a response similar to this:
```
" The image features a close-up view of the front of a bus, with a prominent neon sign clearly displayed. The bus appears to be slightly past its prime condition, beyond its out-of-service section. Inside the bus, we see a depth of text, with the sign saying \"out of service\". A wide array of windows line the side of the double-decker bus, making its overall appearance quite interesting and vintage."
```
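The same request can also be issued from Python. Below is a minimal streaming client sketch, assuming the frontend from this example is listening on `localhost:8000` and exposes the `/generate` endpoint shown above:
```python
# Minimal streaming client sketch for the multimodal aggregated example.
# Assumes the frontend is running locally on port 8000.
import requests

payload = {
    "model": "llava-hf/llava-1.5-7b-hf",
    "image": "http://images.cocodataset.org/test2017/000000155781.jpg",
    "prompt": "Describe the image",
    "max_tokens": 300,
}

with requests.post(
    "http://localhost:8000/generate",
    json=payload,
    headers={"accept": "text/event-stream"},
    stream=True,
    timeout=300,
) as resp:
    resp.raise_for_status()
    # The response is streamed; print each line as it arrives.
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```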
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from io import BytesIO
from typing import AsyncIterator
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args
from dynamo.sdk import dynamo_endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class EncodeWorker:
def __init__(self) -> None:
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.MODEL_ID = self.engine_args.model
self.image_processor = AutoImageProcessor.from_pretrained(
self.MODEL_ID, trust_remote_code=True
)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval()
@dynamo_endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
image = self.open_image(request.image_url)
image_embeds = self.image_processor(images=image, return_tensors="pt")
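# Run the vision tower and multimodal projector without gradients to turn the
# preprocessed pixels into image embeddings the language model can consume.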
with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
image_features = vision_outputs.last_hidden_state
image_features = self.vision_model.multi_modal_projector(image_features)
yield EncodeResponse(
image_features=image_features.tolist()
).model_dump_json()
def open_image(self, image: str) -> Image.Image:
# TODO: Have a separate field for url and non-url inputs - and avoid auto detection
try:
if image.startswith(("http://", "https://")):
response = requests.get(image)
image_data = Image.open(BytesIO(response.content)).convert("RGB")
else:
image_data = Image.open(image).convert("RGB")
except Exception as e:
logger.error(f"Error opening image: {e}")
raise e
return image_data
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from components.processor import Processor
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from utils.protocol import MultiModalRequest
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="Multimodal Example"),
)
class Frontend:
processor = depends(Processor)
@dynamo_endpoint(is_api=True)
async def generate(self, request: MultiModalRequest):
async def content_generator():
async for response in self.processor.generate(request.model_dump_json()):
yield response
return StreamingResponse(content_generator())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.worker import VllmWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
worker = depends(VllmWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.worker_client, self.min_workers)
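# Persist the configured router mode in etcd; _generate reads it for every request.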
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/processor/",
{"router": self.engine_args.router},
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4())
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
)
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == "kv":
# The current KV router does not support multimodal requests because
# it performs cache lookup based solely on prompt tokens. At this stage,
# multimodal data (e.g., image features) is not yet available, so the router
# cannot select the optimal worker using both prompt and image inputs.
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
if router_mode == "random":
response_generator = await self.worker_client.generate(
worker_request.model_dump_json()
)
elif router_mode == "round-robin":
response_generator = await self.worker_client.round_robin(
worker_request.model_dump_json()
)
else:
raise NotImplementedError(f"Router mode {router_mode} not implemented")
output = self._generate_responses(response_generator, request_type)
# TODO: This is a temporary solution to combine the content from the engine generator.
# After having the multimodal support in OpenAI compatible frontend, we can use that directly without the need to manually combine the content.
combined_content = ""
async for response in await self._stream_response(
request, output, request_id, conversation
):
if "choices" in response and len(response["choices"]) > 0:
delta = response["choices"][0].get("delta", {})
content = delta.get("content", "")
combined_content += content
# Yield complete content on final response
if response["choices"][0].get("finish_reason") is not None:
yield combined_content
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
@dynamo_endpoint()
async def generate(self, request: MultiModalRequest):
# TODO: After having the multimodal support in OpenAI compatible frontend, we can use that directly and remove the custom endpoint.
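# Build a LLaVA-style chat message; the <image> placeholder marks where the
# image embeddings are injected during prefill.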
msg = {
"role": "user",
"content": "USER: <image>\nQuestion:" + request.prompt + " Answer:",
}
chat_request = ChatCompletionRequest(
model=request.model,
messages=[msg],
stream=True,
max_tokens=request.max_tokens,
request_id=str(uuid.uuid4()),
)
async for response in self._generate(
chat_request, request.image, RequestType.CHAT
):
yield json.dumps(response)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import signal
import torch
from components.encode_worker import EncodeWorker
from utils.logging import check_required_workers
from utils.protocol import (
EncodeRequest,
EncodeResponse,
MyRequestOutput,
vLLMMultimodalRequest,
)
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmWorker:
encode_worker = depends(EncodeWorker)
def __init__(self):
self.client = None
self.min_workers = 1
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.do_remote_prefill = self.engine_args.remote_prefill
self.model_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
if self.engine_args.remote_prefill:
raise NotImplementedError(
"Remote prefill is not supported for aggregated multimodal example"
)
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
if self.engine_args.router == "kv":
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
.endpoint("encode")
.client()
)
await check_required_workers(self.encode_worker_client, self.min_workers)
self.disaggregated_router = None
logger.info("VllmWorker has been initialized")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("VllmWorker shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
@dynamo_endpoint()
async def generate(self, request: vLLMMultimodalRequest):
image_url = request.image_url
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
image_url=image_url,
).model_dump_json()
)
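# Reconstruct the streamed image embeddings as a tensor on the local device
# (GPU if available) so they can be passed to the engine as multi-modal data.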
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(encode_response.data())
image_features = torch.tensor(
encode_output.image_features, device=device, dtype=torch.float16
)
remote_prefill_params = None
logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": image_features},
),
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/llava-1.5-7b-hf
block-size: 64
max-model-len: 4096
Processor:
router: round-robin
common-configs: [model, block-size, max-model-len]
VllmWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len]
EncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.encode_worker import EncodeWorker
from components.frontend import Frontend
from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(EncodeWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, chat_processor, completions_processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
# TODO: Revisit this later when adding multi-modal support for the frontend.
# If no chat template is provided and tokenizer doesn't have one,
# use a simple format that just concatenates messages
if not request.chat_template and not self.tokenizer.chat_template:
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}Assistant:"
else:
chat_template = request.chat_template or self.tokenizer.chat_template
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
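# Responses arrive as SSE lines ("data: {...}"); stop at the [DONE] sentinel
# and strip the prefix before parsing the JSON payload.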
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from dynamo._core import Client
logger = logging.getLogger(__name__)
async def check_required_workers(
workers_client: Client, required_workers: int, on_change=True, poll_interval=0.5
):
"""Wait until the minimum number of workers are ready."""
worker_ids = workers_client.endpoint_ids()
num_workers = len(worker_ids)
while num_workers < required_workers:
await asyncio.sleep(poll_interval)
worker_ids = workers_client.endpoint_ids()
new_count = len(worker_ids)
if (not on_change) or new_count != num_workers:
logger.info(
f"Waiting for more workers to be ready.\n"
f" Current: {new_count},"
f" Required: {required_workers}"
)
num_workers = new_count
print(f"Workers ready: {worker_ids}")
return worker_ids
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from contextlib import contextmanager
import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.runtime import DistributedRuntime
METADATA_DIR = "/tmp/nixl"
logger = logging.getLogger(__name__)
@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True)
path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
with open(path, "wb") as f:
encoded = msgspec.msgpack.encode(metadata)
logger.info(f"Size of encoded metadata: {len(encoded)}")
f.write(encoded)
try:
yield path
finally:
if os.path.exists(path):
os.remove(path)
def find_remote_metadata(engine_id):
# find and load metadata from METADATA_DIR that do not match engine_id
remote_metadata = []
for file in os.listdir(METADATA_DIR):
if file.endswith(".nixl_meta"):
if file.split(".")[0] != engine_id:
with open(os.path.join(METADATA_DIR, file), "rb") as f:
remote_metadata.append(
msgspec.msgpack.decode(f.read(), type=NixlMetadata)
)
return remote_metadata
class NixlMetadataStore:
NIXL_METADATA_KEY = "nixl_metadata"
def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
self._namespace = namespace
# TODO Remove metadata from etcd on delete
self._stored: set[str] = set()
self._cached: dict[str, NixlMetadata] = {}
self._client = runtime.etcd_client()
if self._client is None:
raise Exception("Cannot be used with static workers")
self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"
async def put(self, engine_id, metadata: NixlMetadata):
serialized_metadata = msgspec.msgpack.encode(metadata)
key = "/".join([self._key_prefix, engine_id])
await self._client.kv_put(key, serialized_metadata, None)
self._stored.add(engine_id)
async def get(self, engine_id) -> NixlMetadata:
try:
if engine_id in self._cached:
return self._cached[engine_id]
key = "/".join([self._key_prefix, engine_id])
key_values = await self._client.kv_get_prefix(key)
deserialized_metadata = None
for item in key_values:
deserialized_metadata = msgspec.msgpack.decode(
item["value"], type=NixlMetadata
)
break
if deserialized_metadata is None:
raise Exception("metadata not found in etcd")
self._cached[engine_id] = deserialized_metadata
# TODO watch for changes and update cache
# self._client.add_watch_callback(
# key,
# self._watch_callback,
# )
except Exception as e:
raise Exception("Error retrieving metadata for engine {engine_id}") from e
return deserialized_metadata
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a msgspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)}
)
class MultiModalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model: str
image: str
max_tokens: int
prompt: str
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
class EncodeRequest(BaseModel):
"""
Serializable request sent to the encode worker; contains the image to embed
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_features: List[List[List[float]]]
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.sdk.lib.config import ServiceConfig
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
config = ServiceConfig.get_instance()
vllm_args = config.as_args(service_name, prefix=prefix)
parser = FlexibleArgumentParser()
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser.add_argument(
"--conditional-disagg",
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length for local prefill. If remote prefill is enabled and the prefill length is greater than this value the request will be sent for remote prefill, otherwise prefill phase will run locally.",
)
parser.add_argument(
"--max-prefill-queue-size",
type=int,
default=3,
help="Maximum queue size for remote prefill. If the prefill queue size is greater than this value, prefill phase of the incoming request will be executed locally.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_prefill_queue_size = args.max_prefill_queue_size
return engine_args