Unverified Commit e924a7c7 authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

feat: generalize VLM embedding extraction (#1388)


Signed-off-by: default avatarhhzhang16 <54051230+hhzhang16@users.noreply.github.com>
Co-authored-by: default avatarKris Hung <krish@nvidia.com>
parent c43ebd24
...@@ -18,7 +18,6 @@ limitations under the License. ...@@ -18,7 +18,6 @@ limitations under the License.
# Multimodal Deployment Examples # Multimodal Deployment Examples
This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo. This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo.
The examples are based on the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model.
## Use the Latest Release ## Use the Latest Release
...@@ -59,11 +58,15 @@ flowchart LR ...@@ -59,11 +58,15 @@ flowchart LR
decode_worker --image_url--> encode_worker decode_worker --image_url--> encode_worker
encode_worker --embeddings--> decode_worker encode_worker --embeddings--> decode_worker
``` ```
```
```bash ```bash
cd $DYNAMO_HOME/examples/multimodal cd $DYNAMO_HOME/examples/multimodal
dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml # Serve a LLaVA 1.5 7B model:
dynamo serve graphs.agg:Frontend -f ./configs/agg-llava.yaml
# Serve a Qwen2.5-VL model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-qwen.yaml
# Serve a Phi3V model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-phi3v.yaml
``` ```
### Client ### Client
...@@ -92,10 +95,13 @@ curl http://localhost:8000/v1/chat/completions \ ...@@ -92,10 +95,13 @@ curl http://localhost:8000/v1/chat/completions \
} }
], ],
"max_tokens": 300, "max_tokens": 300,
"temperature": 0.0,
"stream": false "stream": false
}' }'
``` ```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
You should see a response similar to this: You should see a response similar to this:
```json ```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]} {"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
...@@ -162,6 +168,7 @@ curl http://localhost:8000/v1/chat/completions \ ...@@ -162,6 +168,7 @@ curl http://localhost:8000/v1/chat/completions \
} }
], ],
"max_tokens": 300, "max_tokens": 300,
"temperature": 0.0,
"stream": false "stream": false
}' }'
``` ```
...@@ -171,6 +178,8 @@ You should see a response similar to this: ...@@ -171,6 +178,8 @@ You should see a response similar to this:
{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]} {"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
``` ```
***Note***: disaggregation is currently only confirmed to work with LLaVA. Qwen VL and PhiV are not confirmed to be supported.
## Deployment with Dynamo Operator ## Deployment with Dynamo Operator
These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI. These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI.
...@@ -206,8 +215,12 @@ DYNAMO_TAG=$(dynamo build graphs.agg:Frontend | grep "Successfully built" | awk ...@@ -206,8 +215,12 @@ DYNAMO_TAG=$(dynamo build graphs.agg:Frontend | grep "Successfully built" | awk
# Deploy to Kubernetes # Deploy to Kubernetes
export DEPLOYMENT_NAME=multimodal-agg export DEPLOYMENT_NAME=multimodal-agg
# For aggregated serving: # For aggregated serving with LLaVA:
dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg.yaml dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-llava.yaml
# For aggregated serving with Qwen2.5-VL:
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-qwen.yaml
# For aggregated serving with Phi3V:
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-phi3v.yaml
# For disaggregated serving: # For disaggregated serving:
# export DEPLOYMENT_NAME=multimodal-disagg # export DEPLOYMENT_NAME=multimodal-disagg
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/disagg.yaml # dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/disagg.yaml
...@@ -244,8 +257,11 @@ curl localhost:8000/v1/chat/completions \ ...@@ -244,8 +257,11 @@ curl localhost:8000/v1/chat/completions \
} }
], ],
"max_tokens": 300, "max_tokens": 300,
"temperature": 0.0,
"stream": false "stream": false
}' }'
``` ```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md). For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
...@@ -24,8 +24,8 @@ import torch ...@@ -24,8 +24,8 @@ import torch
from components.disagg_router import PyDisaggregatedRouter from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import VllmEncodeWorker from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue from utils.prefill_queue import PrefillQueue
from utils.protocol import ( from utils.protocol import (
...@@ -117,6 +117,11 @@ class VllmDecodeWorker: ...@@ -117,6 +117,11 @@ class VllmDecodeWorker:
) )
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
if self.do_remote_prefill: if self.do_remote_prefill:
metadata = self.engine_client.nixl_metadata metadata = self.engine_client.nixl_metadata
...@@ -133,18 +138,7 @@ class VllmDecodeWorker: ...@@ -133,18 +138,7 @@ class VllmDecodeWorker:
await self.disaggregated_router.async_init() await self.disaggregated_router.async_init()
else: else:
self.disaggregated_router = None self.disaggregated_router = None
model = LlavaForConditionalGeneration.from_pretrained(
self.engine_args.model,
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()
vision_tower = model.vision_tower
self.embedding_size = (
vision_tower.vision_model.embeddings.position_embedding.num_embeddings
)
else: else:
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16 EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda" EMBEDDINGS_DEVICE = "cuda"
...@@ -161,7 +155,7 @@ class VllmDecodeWorker: ...@@ -161,7 +155,7 @@ class VllmDecodeWorker:
# Create a longer-lived buffer for receiving the image embeddings. # Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty( embeddings = torch.empty(
EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
) )
descriptor = connect.Descriptor(embeddings) descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically). # Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
...@@ -206,13 +200,15 @@ class VllmDecodeWorker: ...@@ -206,13 +200,15 @@ class VllmDecodeWorker:
multi_modal_data, multi_modal_data,
remote_prefill_params, remote_prefill_params,
) = await self.remote_prefill(request) ) = await self.remote_prefill(request)
else: else:
( (
prompt_ids, prompt_ids,
multi_modal_data, multi_modal_data,
remote_prefill_params, remote_prefill_params,
) = await self.local_prefill(request) ) = await self.local_prefill(request)
logger.debug(f"Prompt ids: {prompt_ids}")
logger.debug(f"Multi modal data: {multi_modal_data}")
logger.debug(f"Remote prefill params: {remote_prefill_params}")
# rust HTTP requires Delta streaming # rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA request.sampling_params.output_kind = RequestOutputKind.DELTA
...@@ -227,7 +223,7 @@ class VllmDecodeWorker: ...@@ -227,7 +223,7 @@ class VllmDecodeWorker:
remote_prefill_params=remote_prefill_params, remote_prefill_params=remote_prefill_params,
): ):
logger.debug( logger.debug(
f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}" f"Yielding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
) )
yield MyRequestOutput( yield MyRequestOutput(
request_id=response.request_id, request_id=response.request_id,
...@@ -294,7 +290,9 @@ class VllmDecodeWorker: ...@@ -294,7 +290,9 @@ class VllmDecodeWorker:
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model." "Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
) )
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker. # When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
multi_modal_data = {"image": embeddings} multi_modal_data = construct_mm_data(
self.engine_args.model, encode_output, embeddings, self.embeddings_dtype
)
return prompt_ids, multi_modal_data, remote_prefill_params return prompt_ids, multi_modal_data, remote_prefill_params
...@@ -353,17 +351,16 @@ class VllmDecodeWorker: ...@@ -353,17 +351,16 @@ class VllmDecodeWorker:
# As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size # As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
# so that decode worker can pre-allocate the memory with the correct size. # so that decode worker can pre-allocate the memory with the correct size.
# The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:". # The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
# Since the "<image>" token is included in the prompt, only need to insert (embedding_size - 1) dummy tokens after the image token. # Since the "<image>" token is included in the prompt, only need to insert embedding_size dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
DUMMY_TOKEN_ID = 0 DUMMY_TOKEN_ID = 0
# Find the index of the image token in the prompt token ids # Find the index of the image token in the prompt token ids
image_token_index = request.engine_prompt["prompt_token_ids"].index( image_token_index = request.engine_prompt["prompt_token_ids"].index(
IMAGE_TOKEN_ID self.engine_args.image_token_id
) )
dummy_token_index = image_token_index + 1 dummy_token_index = image_token_index + 1
prompt_ids = ( prompt_ids = (
request.engine_prompt["prompt_token_ids"][:dummy_token_index] request.engine_prompt["prompt_token_ids"][:dummy_token_index]
+ [DUMMY_TOKEN_ID] * (self.embedding_size - 1) + [DUMMY_TOKEN_ID] * self.embedding_size
+ request.engine_prompt["prompt_token_ids"][dummy_token_index:] + request.engine_prompt["prompt_token_ids"][dummy_token_index:]
) )
logger.debug( logger.debug(
......
...@@ -26,7 +26,8 @@ import connect ...@@ -26,7 +26,8 @@ import connect
import httpx import httpx
import torch import torch
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor, LlavaForConditionalGeneration from transformers import AutoImageProcessor
from utils.model import load_vision_model
from utils.protocol import EncodeRequest, EncodeResponse from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args from utils.vllm import parse_vllm_args
...@@ -66,10 +67,7 @@ class VllmEncodeWorker: ...@@ -66,10 +67,7 @@ class VllmEncodeWorker:
self.image_processor = AutoImageProcessor.from_pretrained( self.image_processor = AutoImageProcessor.from_pretrained(
self.MODEL_ID, trust_remote_code=True self.MODEL_ID, trust_remote_code=True
) )
self.vision_model = load_vision_model(self.MODEL_ID)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval()
self._image_cache: dict[str, Image.Image] = {} self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM) self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
...@@ -167,17 +165,32 @@ class VllmEncodeWorker: ...@@ -167,17 +165,32 @@ class VllmEncodeWorker:
logger.debug(f"Processing image for request: {{ id: {request_id} }}") logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt") image_embeds = self.image_processor(images=image, return_tensors="pt")
# Add a batch dimension to everything
for item in image_embeds:
image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
logger.debug(f"Image embeds: {image_embeds}")
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
image_sizes = (
image_embeds["image_sizes"].tolist()
if "image_sizes" in image_embeds
else [image.size]
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
with torch.no_grad(): with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}") embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
vision_outputs = self.vision_model.vision_tower( if isinstance(embeddings, tuple) or isinstance(embeddings, list):
image_embeds["pixel_values"].to(self.vision_model.device) # The result multimodal_embeddings may be a list or tuple of tensors, with each
) # tensor corresponding to a multimodal data item (image or video).
logger.debug("Vision model completed.") # TODO: for multi-image support, this result will contain multiple tensors.
embeddings = embeddings[0].unsqueeze(0)
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
logger.debug( logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}." f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
) )
...@@ -201,6 +214,8 @@ class VllmEncodeWorker: ...@@ -201,6 +214,8 @@ class VllmEncodeWorker:
yield EncodeResponse( yield EncodeResponse(
request_id=request.request_id, request_id=request.request_id,
image_grid_thw=image_grid_thw,
image_sizes=image_sizes,
).model_dump_json() ).model_dump_json()
except Exception as e: except Exception as e:
logger.error(f"Error processing request {request_id}: {e}") logger.error(f"Error processing request {request_id}: {e}")
......
...@@ -25,6 +25,7 @@ import torch ...@@ -25,6 +25,7 @@ import torch
from components.encode_worker import VllmEncodeWorker from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel from pydantic import BaseModel
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest, EncodeResponse from utils.protocol import EncodeRequest, EncodeResponse
...@@ -39,9 +40,6 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic ...@@ -39,9 +40,6 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda" EMBEDDINGS_DEVICE = "cuda"
...@@ -113,9 +111,12 @@ class VllmPrefillWorker: ...@@ -113,9 +111,12 @@ class VllmPrefillWorker:
await self._connector.initialize() await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings. # Create a longer-lived buffer for receiving the image embeddings.
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
embeddings = torch.empty( embeddings = torch.empty(
EMBEDDINGS_SHAPE, embeddings_shape,
dtype=EMBEDDINGS_DTYPE, dtype=self.embeddings_dtype,
device=EMBEDDINGS_DEVICE, device=EMBEDDINGS_DEVICE,
) )
descriptor = connect.Descriptor(embeddings) descriptor = connect.Descriptor(embeddings)
...@@ -248,10 +249,11 @@ class VllmPrefillWorker: ...@@ -248,10 +249,11 @@ class VllmPrefillWorker:
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache, # To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens are inserted based on the embedding size in the worker.py. # some placeholder dummy tokens are inserted based on the embedding size in the worker.py.
# TODO: make this more flexible/model-dependent # TODO: make this more flexible/model-dependent
IMAGE_TOKEN_ID = 32000
embedding_size = embeddings.shape[1] embedding_size = embeddings.shape[1]
padding_size = embedding_size - 1 padding_size = embedding_size
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID) image_token_index = request.prompt_token_ids.index(
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1 dummy_token_index = image_token_index + 1
prompt_token_ids = ( prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index] request.prompt_token_ids[:dummy_token_index]
...@@ -262,7 +264,12 @@ class VllmPrefillWorker: ...@@ -262,7 +264,12 @@ class VllmPrefillWorker:
request_id=request_id, request_id=request_id,
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": embeddings}, multi_modal_data=construct_mm_data(
self.engine_args.model,
encode_output,
embeddings,
self.embeddings_dtype,
),
), ),
sampling_params=sampling_params, sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params, remote_prefill_params=remote_prefill_params,
......
...@@ -188,9 +188,19 @@ class Processor(ProcessMixIn): ...@@ -188,9 +188,19 @@ class Processor(ProcessMixIn):
# The generate endpoint will be used by the frontend to handle incoming requests. # The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint() @endpoint()
async def generate(self, raw_request: MultiModalRequest): async def generate(self, raw_request: MultiModalRequest):
prompt = str(self.engine_args.prompt_template).replace( # Ensure the configured template includes the placeholder
"<prompt>", raw_request.messages[0].content[0].text template = self.engine_args.prompt_template
) if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = { msg = {
"role": "user", "role": "user",
"content": prompt, "content": prompt,
...@@ -201,6 +211,7 @@ class Processor(ProcessMixIn): ...@@ -201,6 +211,7 @@ class Processor(ProcessMixIn):
messages=[msg], messages=[msg],
stream=raw_request.stream, stream=raw_request.stream,
max_tokens=raw_request.max_tokens, max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
) )
image_url = None image_url = None
......
...@@ -26,6 +26,8 @@ VllmDecodeWorker: ...@@ -26,6 +26,8 @@ VllmDecodeWorker:
enforce-eager: true enforce-eager: true
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
enable-prefix-caching: true enable-prefix-caching: true
image-token-id: 32000
num-patches: 576
router: random router: random
tensor-parallel-size: 1 tensor-parallel-size: 1
ServiceArgs: ServiceArgs:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: microsoft/Phi-3.5-vision-instruct
block-size: 64
max-model-len: 4096
trust-remote-code: true
Processor:
router: round-robin
prompt-template: "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 2
mm-processor-kwargs:
num_crops: 16
enable-prefix-caching: true
image-token-id: 32000
num-patches: 757
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: Qwen/Qwen2.5-VL-7B-Instruct
block-size: 64
max-model-len: 4096
Processor:
router: round-robin
prompt-template: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 5
mm-processor-kwargs:
min_pixels: 784
max_pixels: 1003520
fps: 1
enable-prefix-caching: true
image-token-id: 151655
num-patches: 345
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
...@@ -16,6 +16,8 @@ Common: ...@@ -16,6 +16,8 @@ Common:
model: llava-hf/llava-1.5-7b-hf model: llava-hf/llava-1.5-7b-hf
block-size: 64 block-size: 64
max-model-len: 4096 max-model-len: 4096
image-token-id: 32000
num-patches: 576
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}' kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
Processor: Processor:
...@@ -32,7 +34,7 @@ VllmDecodeWorker: ...@@ -32,7 +34,7 @@ VllmDecodeWorker:
workers: 1 workers: 1
resources: resources:
gpu: '1' gpu: '1'
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
VllmPrefillWorker: VllmPrefillWorker:
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
...@@ -40,7 +42,7 @@ VllmPrefillWorker: ...@@ -40,7 +42,7 @@ VllmPrefillWorker:
workers: 1 workers: 1
resources: resources:
gpu: '1' gpu: '1'
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
VllmEncodeWorker: VllmEncodeWorker:
tensor-parallel-size: 1 tensor-parallel-size: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Tuple
import torch
from transformers import AutoConfig
from utils.protocol import EncodeResponse
from vllm import AsyncEngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker
logger = logging.getLogger(__name__)
def load_vision_model(model_id: str) -> torch.nn.Module:
"""
Load a vision model from a HuggingFace model ID.
"""
engine_args = AsyncEngineArgs(model=model_id, trust_remote_code=True)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
worker = Worker(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=True,
)
# Initialize the worker.
worker.init_device()
worker.load_model()
return worker.model_runner.model
def get_vision_embeddings_info(
model_id: str, num_patches: int
) -> Tuple[Tuple[int, int, int], torch.dtype]:
"""Calculate vision embeddings size and dtype using model config
Returns a tuple of (batch_size, num_patches, hidden_dim), dtype.
"""
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
assert num_patches > 0, "Number of patches must be positive"
if not hasattr(config, "torch_dtype"):
raise ValueError("Model config missing required 'torch_dtype' attribute")
if not hasattr(config, "hidden_size"):
logger.warning(
"Model config missing required 'hidden_size' attribute, using 4096"
)
hidden_size = 4096
else:
hidden_size = config.hidden_size
return (1, num_patches, hidden_size), config.torch_dtype
def construct_mm_data(
model: str,
encode_output: EncodeResponse,
image_embeds: torch.Tensor,
embeddings_dtype: torch.dtype,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
image_embeds = image_embeds.to(embeddings_dtype)
if "Qwen2" in model:
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0),
}
}
elif "MiniCPM-V" in model:
return {
"image": {
"image_embeds": image_embeds,
"image_sizes": encode_output.image_sizes,
}
}
else:
return {"image": image_embeds}
...@@ -119,6 +119,7 @@ class MultiModalRequest(BaseModel): ...@@ -119,6 +119,7 @@ class MultiModalRequest(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: List[ChatMessage]
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
temperature: Optional[float] = None
stream: Optional[bool] = True stream: Optional[bool] = True
...@@ -141,6 +142,8 @@ class EncodeRequest(BaseModel): ...@@ -141,6 +142,8 @@ class EncodeRequest(BaseModel):
class EncodeResponse(BaseModel): class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str request_id: str
image_grid_thw: Optional[List[Any]] = None
image_sizes: Optional[List[Any]] = None
class MyRequestOutput(BaseModel): class MyRequestOutput(BaseModel):
......
...@@ -51,6 +51,18 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -51,6 +51,18 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default=3, default=3,
help="Maximum queue size for remote prefill. If the prefill queue size is greater than this value, prefill phase of the incoming request will be executed locally.", help="Maximum queue size for remote prefill. If the prefill queue size is greater than this value, prefill phase of the incoming request will be executed locally.",
) )
parser.add_argument(
"--image-token-id",
type=int,
default=32000,
help="Image token ID used to represent image patches in the token sequence",
)
parser.add_argument(
"--num-patches",
type=int,
default=576,
help="Number of patches the input image is divided into (must be positive)",
)
parser.add_argument( parser.add_argument(
"--prompt-template", "--prompt-template",
type=str, type=str,
...@@ -66,4 +78,6 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -66,4 +78,6 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
engine_args.max_local_prefill_length = args.max_local_prefill_length engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_prefill_queue_size = args.max_prefill_queue_size engine_args.max_prefill_queue_size = args.max_prefill_queue_size
engine_args.prompt_template = args.prompt_template engine_args.prompt_template = args.prompt_template
engine_args.num_patches = args.num_patches
engine_args.image_token_id = args.image_token_id
return engine_args return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment