Unverified Commit 50cd81f3 authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

feat: Add vllm multimodal qwen aggregated support (#2694)

parent dfda6205
...@@ -59,12 +59,14 @@ flowchart LR ...@@ -59,12 +59,14 @@ flowchart LR
pd_worker --> encode_worker pd_worker --> encode_worker
``` ```
***Note*** Only the LLaVA 1.5 7B model is supported. Qwen2.5-VL and Phi3V support will be added in the future. ***Note*** Aggregated serving supports LLaVA 1.5 7B and Qwen2.5-VL-7B-Instruct today. Phi3V support will be added in the future. Disaggregated serving is currently only confirmed for LLaVA (see note below).
```bash ```bash
cd $DYNAMO_HOME/examples/multimodal cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA 1.5 7B model: # Serve a LLaVA 1.5 7B model:
bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
# Serve a Qwen2.5-VL model:
bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
``` ```
### Client ### Client
...@@ -98,6 +100,8 @@ curl http://localhost:8080/v1/chat/completions \ ...@@ -98,6 +100,8 @@ curl http://localhost:8080/v1/chat/completions \
}' }'
``` ```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`.
You should see a response similar to this: You should see a response similar to this:
```json ```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]} {"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
......
...@@ -21,9 +21,8 @@ import signal ...@@ -21,9 +21,8 @@ import signal
import sys import sys
from typing import AsyncIterator, Tuple from typing import AsyncIterator, Tuple
import torch
import uvloop import uvloop
from transformers import AutoImageProcessor, LlavaForConditionalGeneration from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -33,7 +32,9 @@ from dynamo.runtime.logging import configure_dynamo_logging ...@@ -33,7 +32,9 @@ from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint from utils.args import Config, base_parse_args, parse_endpoint
from utils.encode_utils import encode_image_embeddings, get_encoder_components
from utils.image_loader import ImageLoader from utils.image_loader import ImageLoader
from utils.model import load_vision_model
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging() configure_dynamo_logging()
...@@ -70,13 +71,14 @@ class VllmEncodeWorker: ...@@ -70,13 +71,14 @@ class VllmEncodeWorker:
self.image_processor = AutoImageProcessor.from_pretrained( self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True self.model, trust_remote_code=True
) )
# self.vision_model = load_vision_model(self.model) self.vision_model = load_vision_model(self.model)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.model, device_map="auto", torch_dtype=torch.float16
).eval()
self.min_workers = 1 self.min_workers = 1
# Get encoder components for the model
self.vision_encoder, self.projector = get_encoder_components(
self.model, self.vision_model
)
def cleanup(self): def cleanup(self):
pass pass
...@@ -108,49 +110,26 @@ class VllmEncodeWorker: ...@@ -108,49 +110,26 @@ class VllmEncodeWorker:
logger.debug(f"Processing image for request: {{ id: {request_id} }}") logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt") image_embeds = self.image_processor(images=image, return_tensors="pt")
# [gluo NOTE] The commented section is for VLM generalization support,
# will use more generic approach once utils/model.py is fixed,
# see utils/models.py for details.
# # Add a batch dimension to everything
# for item in image_embeds:
# image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
# logger.debug(f"Image embeds: {image_embeds}")
# image_grid_thw = (
# image_embeds["image_grid_thw"].tolist()
# if "image_grid_thw" in image_embeds
# else None
# )
# image_sizes = (
# image_embeds["image_sizes"].tolist()
# if "image_sizes" in image_embeds
# else [image.size]
# )
# logger.debug(
# f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
# )
# with torch.no_grad():
# embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
# if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# # The result multimodal_embeddings may be a list or tuple of tensors, with each
# # tensor corresponding to a multimodal data item (image or video).
# # TODO: for multi-image support, this result will contain multiple tensors.
# embeddings = embeddings[0].unsqueeze(0)
# logger.debug(
# f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
# )
with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
# Encode the image embeddings using model-specific encoder
embeddings = encode_image_embeddings(
model_name=self.model,
image_embeds=image_embeds,
vision_encoder=self.vision_encoder,
projector=self.projector,
)
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
request.image_grid_thw = image_grid_thw
request.embeddings_shape = tuple(embeddings.shape)
descriptor = connect.Descriptor(embeddings) descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable: with self._connector.create_readable(descriptor) as readable:
......
...@@ -24,7 +24,6 @@ from typing import Tuple ...@@ -24,7 +24,6 @@ from typing import Tuple
import torch import torch
import uvloop import uvloop
from transformers import AutoImageProcessor
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
...@@ -47,6 +46,7 @@ from utils.args import ( ...@@ -47,6 +46,7 @@ from utils.args import (
parse_endpoint, parse_endpoint,
) )
from utils.image_loader import ImageLoader from utils.image_loader import ImageLoader
from utils.model import construct_mm_data
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging() configure_dynamo_logging()
...@@ -245,37 +245,15 @@ class VllmPDWorker(VllmBaseWorker): ...@@ -245,37 +245,15 @@ class VllmPDWorker(VllmBaseWorker):
.client() .client()
) )
EMBEDDINGS_DTYPE = torch.float16 self.EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cpu" self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently. # We'll needs this to move data between this worker and remote workers efficiently.
parsed_namespace, _, _ = parse_endpoint(self.endpoint) parsed_namespace, _, _ = parse_endpoint(self.endpoint)
self._connector = connect.Connector() self._connector = connect.Connector()
await self._connector.initialize() await self._connector.initialize()
# embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
# self.engine_args.model, self.engine_args.num_patches
# )
# [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py
# is fixed, see utils/models.py for details.
embeddings_shape = (1, 577, 4096)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
# descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
self.image_loader = ImageLoader() self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.engine_args.model, trust_remote_code=True
)
logger.info("VllmPDWorker has been initialized") logger.info("VllmPDWorker has been initialized")
...@@ -288,10 +266,18 @@ class VllmPDWorker(VllmBaseWorker): ...@@ -288,10 +266,18 @@ class VllmPDWorker(VllmBaseWorker):
request = vLLMMultimodalRequest.model_validate(request) request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
if request.image_url is None: embeddings, descriptor = None, None
# Process embeddings using the connector
embeddings, descriptor = self._embeddings_descriptor # Process embeddings using the connector
# Create a descriptor based on the embedding shape.
embeddings = torch.empty(
request.embeddings_shape,
dtype=self.EMBEDDINGS_DTYPE,
device=self.EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
if request.image_url is None:
if descriptor is None: if descriptor is None:
raise RuntimeError( raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings" "Descriptor is None in PD worker - cannot process embeddings"
...@@ -301,15 +287,17 @@ class VllmPDWorker(VllmBaseWorker): ...@@ -301,15 +287,17 @@ class VllmPDWorker(VllmBaseWorker):
request.serialized_request, descriptor request.serialized_request, descriptor
) )
await read_op.wait_for_completion() await read_op.wait_for_completion()
logger.debug(f"in PD worker, image features: {embeddings}") multi_modal_data = construct_mm_data(
multi_modal_data = embeddings self.engine_args.model,
embeddings,
self.EMBEDDINGS_DTYPE,
request.image_grid_thw,
)
else: else:
# Use PIL image instead of image embeddings # Use PIL image instead of image embeddings
multi_modal_data = await self.image_loader.load_image(request.image_url) multi_modal_data = {
# multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16) "image": await self.image_loader.load_image(request.image_url)
# image input is expected to be (image_num, channel, height, width) }
# logger.info(f"Image features shape: {multi_modal_data.shape}")
# multi_modal_data = multi_modal_data.unsqueeze(0)
# Remove the image features from the request as they are not required # Remove the image features from the request as they are not required
request.image_url = None request.image_url = None
...@@ -331,7 +319,7 @@ class VllmPDWorker(VllmBaseWorker): ...@@ -331,7 +319,7 @@ class VllmPDWorker(VllmBaseWorker):
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"], prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": multi_modal_data}, multi_modal_data=multi_modal_data,
), ),
sampling_params=pd_request.sampling_params, sampling_params=pd_request.sampling_params,
request_id=pd_request.request_id, request_id=pd_request.request_id,
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
import torch
from .model import SupportedModels
logger = logging.getLogger(__name__)
def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
"""
Extract image features using Qwen-style vision encoder.
Args:
vision_encoder: The vision encoder model
image_embeds: Dictionary containing pixel values and grid information
Returns:
Processed image features tensor
Raises:
ValueError: If grid_thw is not provided for Qwen model
"""
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
grid_thw = image_embeds.get("image_grid_thw", None)
if grid_thw is not None:
grid_thw = grid_thw.to(vision_encoder.device)
logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
else:
raise ValueError("grid_thw is not provided")
return (
vision_encoder.get_image_features(pixel_values, grid_thw) # type: ignore
if grid_thw is not None
else vision_encoder.get_image_features(pixel_values) # type: ignore
)
def encode_image_embeddings(
model_name: str,
image_embeds: Dict[str, Any],
vision_encoder: torch.nn.Module,
projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
"""
Encode image embeddings using the appropriate model-specific encoder.
Args:
model_name: The model identifier
image_embeds: Dictionary containing processed image data
vision_encoder: The vision encoder module
projector: The multimodal projector (required for LLaVA-style models)
Returns:
Encoded embeddings tensor with normalized shape
Raises:
ValueError: If projector is missing for LLaVA models
NotImplementedError: If model is not supported
"""
with torch.no_grad():
# Route through the correct encoder based on model
if model_name == SupportedModels.LLAVA_1_5_7B:
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
vision_outputs = vision_encoder(pixel_values)
if projector is None:
raise ValueError(f"Projector not found for LLaVA model: {model_name}")
embeddings = projector(vision_outputs.last_hidden_state)
elif model_name == SupportedModels.QWEN_2_5_VL_7B:
embeddings = get_qwen_image_features(vision_encoder, image_embeds)
else:
raise NotImplementedError(f"Model not supported: {model_name}")
# Normalize output shape
if isinstance(embeddings, (tuple, list)):
embeddings = embeddings[0]
embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings
return embeddings
def get_encoder_components(
model_name: str, vision_model: torch.nn.Module
) -> tuple[Any, Optional[Any]]:
"""
Get the appropriate vision encoder and projector components for a given model.
Args:
model_name: The model identifier
vision_model: The loaded vision model
Returns:
Tuple of (vision_encoder, projector) where types depend on the model
Raises:
NotImplementedError: If model is not supported
"""
if model_name == SupportedModels.LLAVA_1_5_7B:
vision_encoder = vision_model.vision_tower
projector = getattr(vision_model, "multi_modal_projector", None)
return vision_encoder, projector
elif model_name == SupportedModels.QWEN_2_5_VL_7B:
vision_encoder = vision_model
projector = None
return vision_encoder, projector
else:
raise NotImplementedError(f"Model not supported: {model_name}")
...@@ -31,7 +31,9 @@ class ImageLoader: ...@@ -31,7 +31,9 @@ class ImageLoader:
def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM): def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM):
self._http_timeout = 30.0 self._http_timeout = 30.0
self._http_client = httpx.AsyncClient(timeout=self._http_timeout) self._http_client = httpx.AsyncClient(
timeout=self._http_timeout, follow_redirects=True
)
self._image_cache: dict[str, Image.Image] = {} self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
......
...@@ -14,61 +14,47 @@ ...@@ -14,61 +14,47 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from transformers import AutoConfig from transformers import AutoConfig, AutoModel
from utils.protocol import EncodeResponse
from vllm import AsyncEngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker
# from transformers import AutoImageProcessor, LlavaForConditionalGeneration logger = logging.getLogger(__name__)
# from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
logger = logging.getLogger(__name__) class SupportedModels:
"""Supported multimodal model identifiers"""
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
# [gluo NOTE] in vLLM v1, Worker() usage below will results in NotImplementedError,
# must find another way to properly load the vision model given the model name (model_id).
def load_vision_model(model_id: str) -> torch.nn.Module: def load_vision_model(model_id: str) -> torch.nn.Module:
""" """
Load a vision model from a HuggingFace model ID. Load a vision model from a HuggingFace model ID.
""" """
engine_args = AsyncEngineArgs(model=model_id, trust_remote_code=True) model = AutoModel.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
worker = Worker(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=True,
) )
# Initialize the worker. return model
worker.init_device()
worker.load_model()
return worker.model_runner.model
# model = LlavaForConditionalGeneration.from_pretrained(
# model_id, device_map="auto", torch_dtype=torch.float16
# ).eval()
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
# model_id, torch_dtype="auto", device_map="auto"
# ).eval()
# return model
def get_vision_embeddings_info( def get_vision_embeddings_info(
model_id: str, num_patches: int model_id: str,
) -> Tuple[Tuple[int, int, int], torch.dtype]: ) -> Tuple[Tuple[int, int, int], torch.dtype]:
"""Calculate vision embeddings size and dtype using model config """Calculate vision embeddings size and dtype using model config
Returns a tuple of (batch_size, num_patches, hidden_dim), dtype. Returns a tuple of (batch_size, seq_len, hidden_dim), dtype.
""" """
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
assert num_patches > 0, "Number of patches must be positive"
if model_id == SupportedModels.LLAVA_1_5_7B:
seq_len = 577
elif model_id == SupportedModels.QWEN_2_5_VL_7B:
seq_len = 345
else:
seq_len = 0
if not hasattr(config, "torch_dtype"): if not hasattr(config, "torch_dtype"):
raise ValueError("Model config missing required 'torch_dtype' attribute") raise ValueError("Model config missing required 'torch_dtype' attribute")
if not hasattr(config, "hidden_size"): if not hasattr(config, "hidden_size"):
...@@ -78,29 +64,27 @@ def get_vision_embeddings_info( ...@@ -78,29 +64,27 @@ def get_vision_embeddings_info(
hidden_size = 4096 hidden_size = 4096
else: else:
hidden_size = config.hidden_size hidden_size = config.hidden_size
return (1, num_patches, hidden_size), config.torch_dtype return (1, seq_len, hidden_size), config.torch_dtype
def construct_mm_data( def construct_mm_data(
model: str, model: str,
encode_output: EncodeResponse,
image_embeds: torch.Tensor, image_embeds: torch.Tensor,
embeddings_dtype: torch.dtype, embeddings_dtype: torch.dtype,
image_grid_thw: Optional[List[Any]],
) -> Dict[str, torch.Tensor | Dict[str, Any]]: ) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings""" """Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
image_embeds = image_embeds.to(embeddings_dtype) image_embeds = image_embeds.to(embeddings_dtype)
if "Qwen2" in model: if model == SupportedModels.QWEN_2_5_VL_7B:
if image_grid_thw is not None and len(image_grid_thw) > 0:
grid_thw_tensor = torch.tensor(image_grid_thw)
else:
raise ValueError("No image grid provided.")
return { return {
"image": { "image": {
"image_embeds": image_embeds.squeeze(0), "image_embeds": image_embeds.squeeze(0),
"image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0), "image_grid_thw": grid_thw_tensor,
}
}
elif "MiniCPM-V" in model:
return {
"image": {
"image_embeds": image_embeds,
"image_sizes": encode_output.image_sizes,
} }
} }
else: else:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import json import json
from typing import Any, List, Literal, Optional, Union from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec import msgspec
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
...@@ -127,7 +127,8 @@ class MultiModalRequest(BaseModel): ...@@ -127,7 +127,8 @@ class MultiModalRequest(BaseModel):
class vLLMMultimodalRequest(vLLMGenerateRequest): class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: Optional[str] = None image_url: Optional[str] = None
# image_features: Optional[List[List[List[float]]]] = None # Remove once have NIXL support image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[Tuple[int, int, int]] = None
serialized_request: Optional[connect.RdmaMetadata] = None serialized_request: Optional[connect.RdmaMetadata] = None
...@@ -142,15 +143,6 @@ class EncodeRequest(BaseModel): ...@@ -142,15 +143,6 @@ class EncodeRequest(BaseModel):
serialized_request: Optional[connect.RdmaMetadata] = None serialized_request: Optional[connect.RdmaMetadata] = None
class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
image_grid_thw: Optional[List[Any]] = None
image_sizes: Optional[List[Any]] = None
serialized_request: Optional[connect.RdmaMetadata] = None
image_features: List[List[List[float]]] # Remove once have NIXL support
class MyRequestOutput(BaseModel): class MyRequestOutput(BaseModel):
""" """
RequestOutput from vLLM is not serializable by default RequestOutput from vLLM is not serializable by default
......
...@@ -166,8 +166,8 @@ vllm_configs = { ...@@ -166,8 +166,8 @@ vllm_configs = {
], ],
timeout=560, timeout=560,
), ),
"multimodal_agg": VLLMConfig( "multimodal_agg_llava": VLLMConfig(
name="multimodal_agg", name="multimodal_agg_llava",
directory="/workspace/examples/multimodal", directory="/workspace/examples/multimodal",
script_name="agg.sh", script_name="agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm], marks=[pytest.mark.gpu_2, pytest.mark.vllm],
...@@ -180,6 +180,20 @@ vllm_configs = { ...@@ -180,6 +180,20 @@ vllm_configs = {
args=["--model", "llava-hf/llava-1.5-7b-hf"], args=["--model", "llava-hf/llava-1.5-7b-hf"],
timeout=360, timeout=360,
), ),
"multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen",
directory="/workspace/examples/multimodal",
script_name="agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
model="Qwen/Qwen2.5-VL-7B-Instruct",
delayed_start=0,
args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
timeout=360,
),
# TODO: Enable this test case when we have 4 GPUs runners. # TODO: Enable this test case when we have 4 GPUs runners.
# "multimodal_disagg": VLLMConfig( # "multimodal_disagg": VLLMConfig(
# name="multimodal_disagg", # name="multimodal_disagg",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment