Unverified Commit e924a7c7 authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

feat: generalize VLM embedding extraction (#1388)


Signed-off-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
Co-authored-by: Kris Hung <krish@nvidia.com>
parent c43ebd24
......@@ -18,7 +18,6 @@ limitations under the License.
# Multimodal Deployment Examples
This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo.
The examples are based on the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model.
## Use the Latest Release
......@@ -59,11 +58,15 @@ flowchart LR
decode_worker --image_url--> encode_worker
encode_worker --embeddings--> decode_worker
```
```
```bash
cd $DYNAMO_HOME/examples/multimodal
dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
# Serve a LLaVA 1.5 7B model:
dynamo serve graphs.agg:Frontend -f ./configs/agg-llava.yaml
# Serve a Qwen2.5-VL model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-qwen.yaml
# Serve a Phi3V model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-phi3v.yaml
```
### Client
......@@ -92,10 +95,13 @@ curl http://localhost:8000/v1/chat/completions \
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
You should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
......@@ -162,6 +168,7 @@ curl http://localhost:8000/v1/chat/completions \
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
......@@ -171,6 +178,8 @@ You should see a response similar to this:
{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
```
***Note***: Disaggregated serving is currently only confirmed to work with LLaVA. Qwen2.5-VL and Phi3V are not yet confirmed to be supported in disaggregated mode.
## Deployment with Dynamo Operator
These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI.
......@@ -206,8 +215,12 @@ DYNAMO_TAG=$(dynamo build graphs.agg:Frontend | grep "Successfully built" | awk
# Deploy to Kubernetes
export DEPLOYMENT_NAME=multimodal-agg
# For aggregated serving:
dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg.yaml
# For aggregated serving with LLaVA:
dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-llava.yaml
# For aggregated serving with Qwen2.5-VL:
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-qwen.yaml
# For aggregated serving with Phi3V:
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-phi3v.yaml
# For disaggregated serving:
# export DEPLOYMENT_NAME=multimodal-disagg
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/disagg.yaml
......@@ -244,8 +257,11 @@ curl localhost:8000/v1/chat/completions \
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
......@@ -24,8 +24,8 @@ import torch
from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import (
......@@ -117,6 +117,11 @@ class VllmDecodeWorker:
)
runtime = dynamo_context["runtime"]
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
if self.do_remote_prefill:
metadata = self.engine_client.nixl_metadata
......@@ -133,18 +138,7 @@ class VllmDecodeWorker:
await self.disaggregated_router.async_init()
else:
self.disaggregated_router = None
model = LlavaForConditionalGeneration.from_pretrained(
self.engine_args.model,
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()
vision_tower = model.vision_tower
self.embedding_size = (
vision_tower.vision_model.embeddings.position_embedding.num_embeddings
)
else:
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
......@@ -161,7 +155,7 @@ class VllmDecodeWorker:
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional; if not done here, the connect subsystem will take care of this automatically).
......@@ -206,13 +200,15 @@ class VllmDecodeWorker:
multi_modal_data,
remote_prefill_params,
) = await self.remote_prefill(request)
else:
(
prompt_ids,
multi_modal_data,
remote_prefill_params,
) = await self.local_prefill(request)
logger.debug(f"Prompt ids: {prompt_ids}")
logger.debug(f"Multi modal data: {multi_modal_data}")
logger.debug(f"Remote prefill params: {remote_prefill_params}")
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
......@@ -227,7 +223,7 @@ class VllmDecodeWorker:
remote_prefill_params=remote_prefill_params,
):
logger.debug(
f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
f"Yielding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
)
yield MyRequestOutput(
request_id=response.request_id,
......@@ -294,7 +290,9 @@ class VllmDecodeWorker:
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
multi_modal_data = {"image": embeddings}
multi_modal_data = construct_mm_data(
self.engine_args.model, encode_output, embeddings, self.embeddings_dtype
)
return prompt_ids, multi_modal_data, remote_prefill_params
......@@ -353,17 +351,16 @@ class VllmDecodeWorker:
# As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
# so that decode worker can pre-allocate the memory with the correct size.
# The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
# Since the "<image>" token is included in the prompt, only need to insert (embedding_size - 1) dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
# Since the "<image>" token is included in the prompt, only need to insert embedding_size dummy tokens after the image token.
DUMMY_TOKEN_ID = 0
# Find the index of the image token in the prompt token ids
image_token_index = request.engine_prompt["prompt_token_ids"].index(
IMAGE_TOKEN_ID
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1
prompt_ids = (
request.engine_prompt["prompt_token_ids"][:dummy_token_index]
+ [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
+ [DUMMY_TOKEN_ID] * self.embedding_size
+ request.engine_prompt["prompt_token_ids"][dummy_token_index:]
)
logger.debug(
......
......@@ -26,7 +26,8 @@ import connect
import httpx
import torch
from PIL import Image
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from transformers import AutoImageProcessor
from utils.model import load_vision_model
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args
......@@ -66,10 +67,7 @@ class VllmEncodeWorker:
self.image_processor = AutoImageProcessor.from_pretrained(
self.MODEL_ID, trust_remote_code=True
)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval()
self.vision_model = load_vision_model(self.MODEL_ID)
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
......@@ -167,17 +165,32 @@ class VllmEncodeWorker:
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
# Add a batch dimension to everything
for item in image_embeds:
image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
logger.debug(f"Image embeds: {image_embeds}")
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
image_sizes = (
image_embeds["image_sizes"].tolist()
if "image_sizes" in image_embeds
else [image.size]
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
with torch.no_grad():
embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# The result multimodal_embeddings may be a list or tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
# TODO: for multi-image support, this result will contain multiple tensors.
embeddings = embeddings[0].unsqueeze(0)
logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
)
......@@ -201,6 +214,8 @@ class VllmEncodeWorker:
yield EncodeResponse(
request_id=request.request_id,
image_grid_thw=image_grid_thw,
image_sizes=image_sizes,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
......
......@@ -25,6 +25,7 @@ import torch
from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest, EncodeResponse
......@@ -39,9 +40,6 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
......@@ -113,9 +111,12 @@ class VllmPrefillWorker:
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
embeddings = torch.empty(
EMBEDDINGS_SHAPE,
dtype=EMBEDDINGS_DTYPE,
embeddings_shape,
dtype=self.embeddings_dtype,
device=EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
......@@ -248,10 +249,11 @@ class VllmPrefillWorker:
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens are inserted based on the embedding size in the worker.py.
# TODO: make this more flexible/model-dependent
IMAGE_TOKEN_ID = 32000
embedding_size = embeddings.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
padding_size = embedding_size
image_token_index = request.prompt_token_ids.index(
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
......@@ -262,7 +264,12 @@ class VllmPrefillWorker:
request_id=request_id,
prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": embeddings},
multi_modal_data=construct_mm_data(
self.engine_args.model,
encode_output,
embeddings,
self.embeddings_dtype,
),
),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
......
......@@ -188,9 +188,19 @@ class Processor(ProcessMixIn):
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
prompt = str(self.engine_args.prompt_template).replace(
"<prompt>", raw_request.messages[0].content[0].text
)
# Ensure the configured template includes the placeholder
template = self.engine_args.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = {
"role": "user",
"content": prompt,
......@@ -201,6 +211,7 @@ class Processor(ProcessMixIn):
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
......
......@@ -26,6 +26,8 @@ VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
image-token-id: 32000
num-patches: 576
router: random
tensor-parallel-size: 1
ServiceArgs:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Aggregated multimodal serving configuration for the Phi-3.5-vision example model.
Common:
model: microsoft/Phi-3.5-vision-instruct
block-size: 64
max-model-len: 4096
trust-remote-code: true
Processor:
router: round-robin
# The Processor replaces the "<prompt>" placeholder with the user's text and
# rejects templates that do not contain it.
prompt-template: "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 2
mm-processor-kwargs:
num_crops: 16
enable-prefix-caching: true
# Token ID the workers search for in the prompt token IDs to locate the image
# placeholder (--image-token-id).
# NOTE(review): 32000 is also the parser default and the value used by the LLaVA
# configs; confirm it is the correct image token ID for Phi-3.5-vision.
image-token-id: 32000
# Second dimension of the embeddings shape (1, num-patches, hidden_size) computed
# by get_vision_embeddings_info; must be positive.
num-patches: 757
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Aggregated multimodal serving configuration for the Qwen2.5-VL example model.
Common:
# construct_mm_data matches the substring "Qwen2" in this model name and then
# sends image_grid_thw alongside the image embeddings.
model: Qwen/Qwen2.5-VL-7B-Instruct
block-size: 64
max-model-len: 4096
Processor:
router: round-robin
# The Processor replaces the "<prompt>" placeholder with the user's text and
# rejects templates that do not contain it.
prompt-template: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 5
mm-processor-kwargs:
min_pixels: 784
max_pixels: 1003520
fps: 1
enable-prefix-caching: true
# Token ID the workers search for in the prompt token IDs to locate the image
# placeholder (--image-token-id).
image-token-id: 151655
# Second dimension of the embeddings shape (1, num-patches, hidden_size) computed
# by get_vision_embeddings_info; must be positive.
# NOTE(review): presumably tied to the min_pixels/max_pixels settings above and
# the input image size - confirm before changing either.
num-patches: 345
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
......@@ -16,6 +16,8 @@ Common:
model: llava-hf/llava-1.5-7b-hf
block-size: 64
max-model-len: 4096
image-token-id: 32000
num-patches: 576
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
Processor:
......@@ -32,7 +34,7 @@ VllmDecodeWorker:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, kv-transfer-config]
common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
VllmPrefillWorker:
max-num-batched-tokens: 16384
......@@ -40,7 +42,7 @@ VllmPrefillWorker:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, kv-transfer-config]
common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
VllmEncodeWorker:
tensor-parallel-size: 1
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Tuple
import torch
from transformers import AutoConfig
from utils.protocol import EncodeResponse
from vllm import AsyncEngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker
logger = logging.getLogger(__name__)
def load_vision_model(model_id: str) -> torch.nn.Module:
    """Return the model loaded by a single in-process vLLM ``Worker``.

    Engine arguments are built for ``model_id`` with remote code trusted, a
    ``Worker`` is created as rank 0 / local rank 0 and marked as the driver,
    and the model held by the worker's model runner is returned once the
    worker has initialized its device and loaded the weights.

    Args:
        model_id: HuggingFace model ID passed to ``AsyncEngineArgs``.

    Returns:
        The ``worker.model_runner.model`` object of the initialized worker.
    """
    vllm_config = AsyncEngineArgs(
        model=model_id, trust_remote_code=True
    ).create_engine_config()

    # The init method is derived from this host's IP and a currently free port.
    init_method = get_distributed_init_method(get_ip(), get_open_port())

    worker_kwargs = dict(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method=init_method,
        is_driver_worker=True,
    )
    vision_worker = Worker(**worker_kwargs)

    # The device must be initialized before the weights are loaded.
    vision_worker.init_device()
    vision_worker.load_model()

    return vision_worker.model_runner.model
def get_vision_embeddings_info(
    model_id: str, num_patches: int
) -> Tuple[Tuple[int, int, int], torch.dtype]:
    """Derive the vision-embeddings shape and dtype from the model config.

    Args:
        model_id: HuggingFace model ID whose config is loaded via
            ``AutoConfig.from_pretrained`` with remote code trusted.
        num_patches: Size of the second dimension of the returned shape; must
            be positive.

    Returns:
        A tuple ``((1, num_patches, hidden_size), torch_dtype)``.

    Raises:
        ValueError: If ``num_patches`` is not positive, or if the config does
            not specify a ``torch_dtype``.
    """
    # Validate up front with a real exception: ``assert`` is stripped under
    # ``python -O``, and failing early avoids a pointless config load.
    if num_patches <= 0:
        raise ValueError("Number of patches must be positive")

    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    # A config can carry the attribute with a value of None, so checking for
    # presence alone is not enough.
    embeddings_dtype = getattr(config, "torch_dtype", None)
    if embeddings_dtype is None:
        raise ValueError("Model config missing required 'torch_dtype' attribute")

    hidden_size = getattr(config, "hidden_size", None)
    if hidden_size is None:
        # Composite vision-language configs may keep the language model's
        # dimensions on a nested ``text_config`` rather than at the top level.
        text_config = getattr(config, "text_config", None)
        hidden_size = getattr(text_config, "hidden_size", None)
    if hidden_size is None:
        logger.warning(
            "Model config missing required 'hidden_size' attribute, using 4096"
        )
        hidden_size = 4096

    return (1, num_patches, hidden_size), embeddings_dtype
def construct_mm_data(
model: str,
encode_output: EncodeResponse,
image_embeds: torch.Tensor,
embeddings_dtype: torch.dtype,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
image_embeds = image_embeds.to(embeddings_dtype)
if "Qwen2" in model:
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0),
}
}
elif "MiniCPM-V" in model:
return {
"image": {
"image_embeds": image_embeds,
"image_sizes": encode_output.image_sizes,
}
}
else:
return {"image": image_embeds}
......@@ -119,6 +119,7 @@ class MultiModalRequest(BaseModel):
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
stream: Optional[bool] = True
......@@ -141,6 +142,8 @@ class EncodeRequest(BaseModel):
class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
image_grid_thw: Optional[List[Any]] = None
image_sizes: Optional[List[Any]] = None
class MyRequestOutput(BaseModel):
......
......@@ -51,6 +51,18 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default=3,
help="Maximum queue size for remote prefill. If the prefill queue size is greater than this value, prefill phase of the incoming request will be executed locally.",
)
parser.add_argument(
"--image-token-id",
type=int,
default=32000,
help="Image token ID used to represent image patches in the token sequence",
)
parser.add_argument(
"--num-patches",
type=int,
default=576,
help="Number of patches the input image is divided into (must be positive)",
)
parser.add_argument(
"--prompt-template",
type=str,
......@@ -66,4 +78,6 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_prefill_queue_size = args.max_prefill_queue_size
engine_args.prompt_template = args.prompt_template
engine_args.num_patches = args.num_patches
engine_args.image_token_id = args.image_token_id
return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment