Unverified Commit 5a30923f authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

feat: Support OAI frontend format and add async image handing for multimodal (#1214)


Co-authored-by: default avatarJ Wyman <jwyman@nvidia.com>
parent 8cc13610
...@@ -55,21 +55,35 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml ...@@ -55,21 +55,35 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
In another terminal: In another terminal:
```bash ```bash
curl -X 'POST' \ curl http://localhost:8000/v1/chat/completions \
'http://localhost:8000/generate' \ -H "Content-Type: application/json" \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{ -d '{
"model":"llava-hf/llava-1.5-7b-hf", "model": "llava-hf/llava-1.5-7b-hf",
"image":"http://images.cocodataset.org/test2017/000000155781.jpg", "messages": [
"prompt":"Describe the image", {
"max_tokens":300 "role": "user",
}' | jq "content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}'
``` ```
You should see a response similar to this: You should see a response similar to this:
``` ```
" The image features a close-up view of the front of a bus, with a prominent neon sign clearly displayed. The bus appears to be slightly past its prime condition, beyond its out-of-service section. Inside the bus, we see a depth of text, with the sign saying \"out of service\". A wide array of windows line the side of the double-decker bus, making its overall appearance quite interesting and vintage." {"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
``` ```
## Multimodal Disaggregated serving ## Multimodal Disaggregated serving
...@@ -108,19 +122,33 @@ dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml ...@@ -108,19 +122,33 @@ dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml
In another terminal: In another terminal:
```bash ```bash
curl -X 'POST' \ curl http://localhost:8000/v1/chat/completions \
'http://localhost:8000/generate' \ -H "Content-Type: application/json" \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{ -d '{
"model":"llava-hf/llava-1.5-7b-hf", "model": "llava-hf/llava-1.5-7b-hf",
"image":"http://images.cocodataset.org/val2017/000000324158.jpg", "messages": [
"prompt":"Describe the mood and setting of this image in two sentences. What time of day do you think it is?", {
"max_tokens":300 "role": "user",
}' | jq "content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}'
``` ```
You should see a response similar to this: You should see a response similar to this:
``` ```
" The image depicts a man moving across a field on a skateboard. The setting appears to be joyful, and this activity suggests that the man is enjoying an outdoor adventure. Additionally, a pet dog is probably accompanying, contributing to the positive mood. The mood and setting of the image appear lively and shoal. The sun is most likely low in the sky, as this would produce a nice daylight." {"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
``` ```
...@@ -197,13 +197,11 @@ class VllmDecodeWorker: ...@@ -197,13 +197,11 @@ class VllmDecodeWorker:
async def generate(self, request: vLLMMultimodalRequest): async def generate(self, request: vLLMMultimodalRequest):
request_id = request.request_id request_id = request.request_id
image_url = request.image_url image_url = request.image_url
logger.info( logger.info(f"Received multimodal request {{ id: {request_id} }}.")
f"Received multimodal request {{ id: {request_id}, image_url: '{image_url}' }}."
)
embeddings = None embeddings = None
if self.do_remote_prefill: if self.do_remote_prefill:
logger.debug( logger.debug(
f"Disaggregated: request {{ id: {request_id}, image_url: '{image_url}' }}" f"Disaggregated: request {{ id: {request_id} }}"
" prefill worker will populate the decode model's key-value cache ahead of time;" " prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required." " no direct encode worker interaction required."
) )
...@@ -224,7 +222,7 @@ class VllmDecodeWorker: ...@@ -224,7 +222,7 @@ class VllmDecodeWorker:
if self.do_remote_prefill and disagg_router_decision: if self.do_remote_prefill and disagg_router_decision:
logger.debug( logger.debug(
f"Prefilling remotely for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling remotely for request {{ id: {request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
remote_prefill_params = RemotePrefillParams( remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True, is_remote_prefill=True,
...@@ -237,7 +235,7 @@ class VllmDecodeWorker: ...@@ -237,7 +235,7 @@ class VllmDecodeWorker:
else: else:
remote_prefill_params = None remote_prefill_params = None
logger.debug( logger.debug(
f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling locally for request {{ id: {request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache. # The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...@@ -260,7 +258,7 @@ class VllmDecodeWorker: ...@@ -260,7 +258,7 @@ class VllmDecodeWorker:
else: else:
logger.debug( logger.debug(
f"Aggregated: request {{ id: {request_id}, image_url: '{image_url}' }}" f"Aggregated: request {{ id: {request_id} }}"
" no prefill worker available, embeddings directly from encode worker." " no prefill worker available, embeddings directly from encode worker."
) )
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor. # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
...@@ -295,8 +293,8 @@ class VllmDecodeWorker: ...@@ -295,8 +293,8 @@ class VllmDecodeWorker:
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker. # At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = None remote_prefill_params = None
logger.info( logger.debug(
f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling locally for request {{ id: {request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
prompt_ids = request.engine_prompt["prompt_token_ids"] prompt_ids = request.engine_prompt["prompt_token_ids"]
......
...@@ -13,13 +13,17 @@ ...@@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import base64
import binascii
import logging import logging
from io import BytesIO from io import BytesIO
from queue import Queue from queue import Queue
from typing import AsyncIterator from typing import AsyncIterator, Optional
from urllib.parse import urlparse
import connect import connect
import requests import httpx
import torch import torch
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor, LlavaForConditionalGeneration from transformers import AutoImageProcessor, LlavaForConditionalGeneration
...@@ -70,14 +74,83 @@ class VllmEncodeWorker: ...@@ -70,14 +74,83 @@ class VllmEncodeWorker:
self._image_cache: dict[str, Image.Image] = {} self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM) self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
self._http_client: Optional[httpx.AsyncClient] = None
self._http_timeout = 30.0
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
response = await self._http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
self._cache_queue.put(image_url_lower)
return image
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
@endpoint() @endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]: async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
logger.debug( logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
f"Received encode request: {{ id: {request.request_id}, image_url: '{request.image_url}' }}."
)
request_id = request.request_id request_id = request.request_id
image_url = request.image_url.lower()
# The following steps encode the requested image and provided useful embeddings. # The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL. # 1. Open the image from the provided URL.
...@@ -89,64 +162,49 @@ class VllmEncodeWorker: ...@@ -89,64 +162,49 @@ class VllmEncodeWorker:
# 7. Await for the write operation to complete. # 7. Await for the write operation to complete.
# 8. Yield the encode response. # 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it. try:
if request.image_url in self._image_cache: image = await self.load_image(request.image_url)
image = self._image_cache[image_url]
logger.debug(
f"Image found in cache for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
else:
image = self.open_image(image_url)
logger.debug(
f"Downloading/opening image for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if self._cache_queue.full():
oldest_image_url = self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[request.image_url] = image
self._cache_queue.put(request.image_url)
logger.debug(
f"Processing image for request: {{ id: {request_id}, image_url: '{image_url}' }}"
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
with torch.no_grad(): logger.debug(f"Processing image for request: {{ id: {request_id} }}")
logger.debug(f"Vision model device: {self.vision_model.device}") image_embeds = self.image_processor(images=image, return_tensors="pt")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state with torch.no_grad():
embeddings = self.vision_model.multi_modal_projector(embeddings) logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
logger.debug( embeddings = vision_outputs.last_hidden_state
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}." embeddings = self.vision_model.multi_modal_projector(embeddings)
)
if request.serialized_request is None: logger.debug(
logger.error( f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
f"Request serialized_request is None for request: {{ id: {request_id}, image_url: '{image_url}' }}."
) )
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime). if request.serialized_request is None:
descriptor = connect.Descriptor(embeddings) logger.error(
# Create a write operation using the serialized request and the descriptor. f"Request serialized_request is None for request: {{ id: {request_id} }}."
# This will begin the RDMA transfer of the embeddings to the remote worker. )
write_op = await self._connector.begin_write(
descriptor, # Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
request.serialized_request, descriptor = connect.Descriptor(embeddings)
) # Create a write operation using the serialized request and the descriptor.
# Await for the write operation to complete. # This will begin the RDMA transfer of the embeddings to the remote worker.
# This will block until the data has been written to the remote worker or an error occurs. write_op = await self._connector.begin_write(
await write_op.wait_for_completion() descriptor,
request.serialized_request,
yield EncodeResponse( )
request_id=request.request_id, # Await for the write operation to complete.
).model_dump_json() # This will block until the data has been written to the remote worker or an error occurs.
await write_op.wait_for_completion()
yield EncodeResponse(
request_id=request.request_id,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
raise
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
...@@ -155,19 +213,6 @@ class VllmEncodeWorker: ...@@ -155,19 +213,6 @@ class VllmEncodeWorker:
# We'll needs this to move data between this worker and remote workers efficiently. # We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize() await self._connector.initialize()
# Initialize HTTP client with default limits
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
logger.info("Startup completed.") logger.info("Startup completed.")
def open_image(self, image: str) -> Image.Image:
# TODO: Have a seperate field for url and non url - and avoid auto detection
try:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if image.startswith("http") or image.startswith("https"):
response = requests.get(image)
image_data = Image.open(BytesIO(response.content)).convert("RGB")
else:
image_data = Image.open(image).convert("RGB")
return image_data
except Exception as e:
logger.error(f"Error opening image: {e}")
raise e
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
from components.processor import Processor from components.processor import Processor
...@@ -37,10 +38,14 @@ logger = logging.getLogger(__name__) ...@@ -37,10 +38,14 @@ logger = logging.getLogger(__name__)
class Frontend: class Frontend:
processor = depends(Processor) processor = depends(Processor)
@api() @api(name="v1/chat/completions")
async def generate(self, request: MultiModalRequest): async def generate(self, request: MultiModalRequest):
async def content_generator(): async def content_generator():
async for response in self.processor.generate(request.model_dump_json()): async for response in self.processor.generate(request.model_dump_json()):
yield response try:
s = json.loads(response)
yield s
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse JSON response: {e}")
return StreamingResponse(content_generator()) return StreamingResponse(content_generator(), media_type="text/event-stream")
...@@ -191,7 +191,7 @@ class VllmPrefillWorker: ...@@ -191,7 +191,7 @@ class VllmPrefillWorker:
image_url = request.multimodal_data_source["image_url"] image_url = request.multimodal_data_source["image_url"]
logger.info( logger.info(
f"Received prefill request {{ id: {request_id}, engine_id: {engine_id}, image_url: '{image_url}' }}." f"Received prefill request {{ id: {request_id}, engine_id: {engine_id} }}."
) )
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor. # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
......
...@@ -145,20 +145,10 @@ class Processor(ProcessMixIn): ...@@ -145,20 +145,10 @@ class Processor(ProcessMixIn):
output = self._generate_responses(response_generator, request_type) output = self._generate_responses(response_generator, request_type)
# TODO: This is a temporary solution to combine the content from the engine generator.
# After having the multimodal support in OpenAI compatible frontend, we can use that directly without the need to manually combine the content.
combined_content = ""
async for response in await self._stream_response( async for response in await self._stream_response(
request, output, request_id, conversation request, output, request_id, conversation
): ):
if "choices" in response and len(response["choices"]) > 0: yield response
delta = response["choices"][0].get("delta", {})
content = delta.get("content", "")
combined_content += content
# Yield complete content on final response
if response["choices"][0].get("finish_reason") is not None:
yield combined_content
# This method is used to process the responses from the engine generator. # This method is used to process the responses from the engine generator.
async def _generate_responses( async def _generate_responses(
...@@ -196,22 +186,29 @@ class Processor(ProcessMixIn): ...@@ -196,22 +186,29 @@ class Processor(ProcessMixIn):
# The generate endpoint will be used by the frontend to handle incoming requests. # The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint() @endpoint()
async def generate(self, request: MultiModalRequest): async def generate(self, raw_request: MultiModalRequest):
# TODO: After having the multimodal support in OpenAI compatible frontend, we can use that directly and remove the custom endpoint.
msg = { msg = {
"role": "user", "role": "user",
"content": "USER: <image>\nQuestion:" + request.prompt + " Answer:", "content": "USER: <image>\nQuestion:"
+ raw_request.messages[0].content[0].text
+ " Answer:",
} }
chat_request = ChatCompletionRequest( chat_request = ChatCompletionRequest(
model=request.model, model=raw_request.model,
messages=[msg], messages=[msg],
stream=True, stream=raw_request.stream,
max_tokens=request.max_tokens, max_tokens=raw_request.max_tokens,
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
) )
image_url = None
async for response in self._generate( for message in raw_request.messages:
chat_request, request.image, RequestType.CHAT for item in message.content:
): if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
yield json.dumps(response) yield json.dumps(response)
...@@ -174,21 +174,65 @@ class ChatProcessor: ...@@ -174,21 +174,65 @@ class ChatProcessor:
conversation: List, conversation: List,
): ):
request_metadata = RequestResponseMetadata(request_id=request_id) request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream: if request.stream:
raise ValueError("Only streaming responses are supported") # Handle streaming response
async for raw_response in self.openai_serving.chat_completion_stream_generator( async for raw_response in self.openai_serving.chat_completion_stream_generator(
request, request,
result_generator, result_generator,
request_id, request_id,
request.model, request.model,
conversation, conversation,
self.tokenizer, self.tokenizer,
request_metadata, request_metadata,
): ):
if raw_response.startswith("data: [DONE]"): yield raw_response
break else:
response = json.loads(raw_response.lstrip("data: ")) # Handle non-streaming response
yield response # Collect all chunks into a single response
full_response = None
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
if full_response is None:
# Initialize the full response structure
full_response = {
"id": response.get("id", ""),
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": response.get("index", 0),
"message": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
# Concatenate content if it exists
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
full_response["choices"][0]["message"]["content"] += content
# Update finish reason if present
if "finish_reason" in response["choices"][0]:
full_response["choices"][0]["finish_reason"] = response[
"choices"
][0]["finish_reason"]
if full_response is not None:
yield json.dumps(full_response)
class CompletionsProcessor: class CompletionsProcessor:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import json import json
from typing import Any, List, Optional from typing import Any, List, Literal, Optional, Union
import connect import connect
import msgspec import msgspec
...@@ -92,12 +92,34 @@ class vLLMGenerateRequest(BaseModel): ...@@ -92,12 +92,34 @@ class vLLMGenerateRequest(BaseModel):
) )
class TextContent(BaseModel):
type: Literal["text"]
text: str
class ImageURLDetail(BaseModel):
url: str
class ImageContent(BaseModel):
type: Literal["image_url"]
image_url: ImageURLDetail
MessageContent = Union[TextContent, ImageContent]
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: List[MessageContent]
class MultiModalRequest(BaseModel): class MultiModalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
model: str model: str
image: str messages: List[ChatMessage]
max_tokens: int max_tokens: Optional[int] = None
prompt: str stream: Optional[bool] = True
class vLLMMultimodalRequest(vLLMGenerateRequest): class vLLMMultimodalRequest(vLLMGenerateRequest):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment