Unverified Commit 9831aec4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Dynamic image size support for VLMs (#5276)


Signed-off-by: default avatarXiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: default avatarXiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: default avatarywang96 <ywang@roblox.com>
Co-authored-by: default avatarxwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: default avatarRoger Wang <136131678+ywang96@users.noreply.github.com>
parent 482045ee
......@@ -6,7 +6,7 @@ from transformers import CLIPVisionConfig, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
......@@ -20,8 +20,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
......@@ -51,28 +53,10 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
if mask.sum() != image_feature_size:
raise ValueError(f"image_feature_size should be {image_feature_size}, "
f"but found: {mask.sum()}")
inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
vision_embeddings.shape[-1])
return inputs_embeds
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs = LlavaImagePixelInputs
......@@ -96,8 +80,30 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
......@@ -112,7 +118,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
......
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
......@@ -10,7 +10,7 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
......@@ -21,13 +21,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length)
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
......@@ -39,16 +40,27 @@ _KEYS_TO_MODIFY_MAPPING = {
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
data: BatchedTensors
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
LlavaNextImageInputs = LlavaNextImagePixelInputs
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
def _get_llava_next_num_unpadded_features(
height: int,
width: int,
......@@ -56,7 +68,6 @@ def _get_llava_next_num_unpadded_features(
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
# Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
......@@ -64,9 +75,13 @@ def _get_llava_next_num_unpadded_features(
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
if new_height % 2 == 1:
new_height += 1
current_height = new_height
else:
new_width = (width * current_height) // height
if new_width % 2 == 1:
new_width += 1
current_width = new_width
unpadded_features = current_height * current_width
......@@ -74,7 +89,8 @@ def _get_llava_next_num_unpadded_features(
return (unpadded_features, newline_features)
def _get_llava_next_image_feature_size(
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
......@@ -89,7 +105,9 @@ def _get_llava_next_image_feature_size(
)
base_feature_size = num_patches * num_patches
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
......@@ -110,14 +128,16 @@ def _get_llava_next_image_feature_size(
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_llava_next_image_feature_size(
hf_config, input_height=dummy_height, input_width=dummy_width)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=dummy_height,
input_width=dummy_width,
)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
......@@ -139,27 +159,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
if isinstance(image, Image.Image):
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image = image.resize((w, h))
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
raise TypeError(f"Invalid type for 'image': {type(image)}")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
......@@ -172,8 +212,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
......@@ -196,24 +236,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
_, num_channels, _, _ = self.vlm_config.image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
height = width = self.config.vision_config.image_size
if list(data.shape[2:]) != [num_channels, height, width]:
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"num_patches plus {[num_channels, height, width]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
......@@ -223,14 +245,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
if pixel_values is None or image_sizes is None:
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
......@@ -240,7 +262,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
data=pixel_values,
image_sizes=self._validate_image_sizes(image_sizes),
)
......@@ -267,15 +289,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
orig_width, orig_height = image_size
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
......@@ -289,13 +310,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
other_patch_embeds = patch_embeddings[1:]
# image_aspect_ratio == "anyres"
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
(orig_width, orig_height),
image_size,
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
other_patch_embeds = other_patch_embeds \
.view(num_patch_width, num_patch_height, height, width, -1)
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
......@@ -333,44 +356,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
self,
inputs: LlavaNextImagePixelInputs,
) -> BatchedTensors:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return stacked_image_features.view(b, num_patches,
*stacked_image_features.shape[-2:])
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
patch_embeddings = self.multi_modal_projector(image_features)
self, image_input: LlavaNextImageInputs) -> BatchedTensors:
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = image_input["data"].shape[0]
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_width = default_height = vision_config.image_size
image_sizes = torch.as_tensor([[default_width, default_height]
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
merged_patch_embeddings = [
return [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features,
patch_features_batch,
strategy="spatial_unpad")
for i, patch_features in enumerate(patch_embeddings)
for i, patch_features_batch in enumerate(patch_embeddings)
]
return torch.stack(merged_patch_embeddings, dim=0)
def forward(
self,
input_ids: torch.Tensor,
......@@ -404,8 +436,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each grid patch for each input image.
Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
image_sizes: The original `(width, height)` for each input image.
Expects a batch with shape `[1, num_patches, 3, h, w]`.
image_sizes: The original `(height, width)` for each input image.
Expects a batch with shape `[1, 2]`.
See also:
......
......@@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import numpy as np
import torch
......@@ -22,8 +24,8 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
......@@ -34,10 +36,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
logger = init_logger(__name__)
......@@ -251,50 +255,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
data: BatchedTensors
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
image_sizes: torch.Tensor
"""Shape: (batch_size, 2)"""
def _get_phi3v_image_feature_size(
*,
input_height: int,
input_width: int,
) -> int:
h, w = input_height, input_width
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_phi3v_image_feature_size(
input_height=dummy_height,
input_width=dummy_width,
)
Note that `num_patches` may be different for each batch.
"""
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
return seq_data, mm_data
This should be in `(height, width)` format.
"""
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
......@@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
return padded_width, padded_height
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
transposed = False
if width < height:
......@@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
return padded_width, padded_height
def _image_processor(ctx: InputContext,
image: object) -> Dict[str, torch.Tensor]:
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: PretrainedConfig,
*,
input_height: int,
input_width: int,
) -> int:
num_crops = getattr(hf_config, "num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != _calc_hd_transform_size(width=image.width,
height=image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+ (new_height // 336 + 1) * 12
image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
raise TypeError(f"Invalid type for 'image': {type(image)}")
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
image_feature_size = get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
idx: int) -> List[int]:
assert idx > 0
tokenizer = cached_get_tokenizer(model_config.tokenizer)
# We need to get the token for "<", not "▁<"
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
f"a<|image_{idx}|>", add_special_tokens=False)
assert a_token_id == a_token_id_
return image_placeholder_token_ids
def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(PretrainedConfig)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
w, h = image_data.size
w, h = _calc_hd_transform_size(width=w, height=h)
image_feature_size = get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
prompt = llm_inputs.get("prompt")
if prompt is None:
new_prompt = None
else:
if prompt.count("<|image|>") > 0:
logger.warning("Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating <|image|> tokens.")
elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")
new_prompt = prompt
prompt_token_ids = llm_inputs["prompt_token_ids"]
image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
new_token_ids: List[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
new_token_ids.append(multimodal_config.image_token_id)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
break
else:
new_token_ids.append(prompt_token_ids[i])
# NOTE: Create a defensive copy of the original inputs
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
return input_processor_for_clip(
model_config,
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
llm_inputs,
image_token_id=multimodal_config.image_token_id,
image_feature_size_override=image_feature_size,
)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self,
......@@ -363,6 +445,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.vlm_config = vlm_config
self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size,
......@@ -376,13 +460,21 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
if pixel_values is not None and image_sizes is not None:
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,
image_sizes=image_sizes)
return None
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
......
import torch
from vllm.multimodal import BatchedTensors
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: BatchedTensors,
image_token_id: int) -> torch.Tensor:
"""
Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
Note:
This updates `inputs_embeds` in place.
"""
mask = (input_ids == image_token_id)
num_expected_tokens = mask.sum()
if isinstance(vision_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in vision_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(vision_embeddings)
return inputs_embeds
from .base import MultiModalDataDict, MultiModalPlugin
from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
MultiModalPlugin)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
......@@ -11,8 +12,10 @@ See also:
"""
__all__ = [
"BatchedTensors",
"MultiModalDataDict",
"MultiModalInputs",
"MultiModalPlugin",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
"MultiModalDataDict",
]
import sys
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
TypedDict, TypeVar, Union)
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
TypeVar, Union)
import torch
import torch.types
from PIL import Image
from torch import nn
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
if TYPE_CHECKING:
logger = init_logger(__name__)
BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of tensors with one element per batch.
"""
if sys.version_info < (3, 9):
# UserDict cannot be subscripted
class _MultiModalInputsBase(UserDict):
pass
else:
class _MultiModalInputsBase(UserDict[str, torch.Tensor]):
pass
class MultiModalInputs(_MultiModalInputsBase):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
"""
@staticmethod
def try_concat(
tensors: List[torch.Tensor],
*,
device: torch.types.Device,
) -> BatchedTensors:
# Avoid initializing CUDA too early
import torch
from PIL import Image
from torch import nn
logger = init_logger(__name__)
unbatched_shape = tensors[0].shape[1:]
for tensor in tensors:
if tensor.shape[1:] != unbatched_shape:
return [
tensor.squeeze(0).to(device=device) for tensor in tensors
]
return torch.cat(tensors, dim=0).to(device=device)
@staticmethod
def batch(
inputs_list: List["MultiModalInputs"],
device: torch.types.Device,
) -> Dict[str, BatchedTensors]:
"""Batch multiple inputs together into a dictionary."""
if len(inputs_list) == 0:
return {}
keys = inputs_list[0].keys()
N = TypeVar("N", bound=Type["nn.Module"])
item_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
for inputs in inputs_list:
if inputs.keys() != keys:
msg = f"Inputs do not share the same keys ({keys})"
raise ValueError(msg)
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalInputs.try_concat(item_list, device=device)
for k, item_list in item_lists.items()
}
class MultiModalDataBuiltins(TypedDict, total=False):
image: "Image.Image"
image: Image.Image
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
......@@ -29,12 +96,13 @@ to the model by the corresponding mapper. By default, the mapper of
the corresponding plugin with the same modality key is applied.
"""
MultiModalInputMapper = Callable[[InputContext, object], Dict[str,
"torch.Tensor"]]
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalPlugin(ABC):
"""
......@@ -48,8 +116,7 @@ class MultiModalPlugin(ABC):
"""
def __init__(self) -> None:
self._input_mappers: Dict[Type["nn.Module"],
MultiModalInputMapper] = {}
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
@abstractmethod
def get_data_key(self) -> str:
......@@ -60,7 +127,7 @@ class MultiModalPlugin(ABC):
@abstractmethod
def _default_input_mapper(self, ctx: InputContext,
data: object) -> Dict[str, "torch.Tensor"]:
data: object) -> MultiModalInputs:
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
......@@ -80,6 +147,7 @@ class MultiModalPlugin(ABC):
See also:
:ref:`input_processing_pipeline`
:ref:`adding_a_new_multimodal_model`
"""
def wrapper(model_cls: N) -> N:
......@@ -97,7 +165,7 @@ class MultiModalPlugin(ABC):
return wrapper
def map_input(self, model_config: ModelConfig,
data: object) -> Dict[str, "torch.Tensor"]:
data: object) -> MultiModalInputs:
"""
Apply an input mapper to a data passed
to the model, transforming the data into a dictionary of model inputs.
......@@ -106,7 +174,8 @@ class MultiModalPlugin(ABC):
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]
See also:
:ref:`adding_a_new_multimodal_model`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
......
from functools import lru_cache
from typing import Dict
from typing import List, Optional, Tuple, TypeVar
import torch
from PIL import Image
from transformers import PreTrainedTokenizerBase
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from .base import MultiModalPlugin
from .base import MultiModalInputs, MultiModalPlugin
logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)
# Utilities for image input processors
_T = TypeVar("_T", str, int)
def repeat_and_pad_token(
token: _T,
*,
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> List[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
if pad_token_right is not None:
replacement = replacement + [pad_token_right]
return replacement
def repeat_and_pad_image_tokens(
tokenizer: PreTrainedTokenizerBase,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
image_token_id: int,
repeat_count: int = 1,
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
if prompt is None:
new_prompt = None
else:
image_token_str = tokenizer.decode(image_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
replacement_str = "".join(
repeat_and_pad_token(
image_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
image_token_count = prompt.count(image_token_str)
# This is an arbitrary number to distinguish between the two cases
if image_token_count > 16:
logger.warning(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", image_token_str)
elif image_token_count > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")
# The image tokens are removed to be consistent with HuggingFace
new_prompt = prompt.replace(image_token_str, replacement_str, 1)
new_token_ids: List[int] = []
for i, token in enumerate(prompt_token_ids):
if token == image_token_id:
replacement_ids = repeat_and_pad_token(
image_token_id,
repeat_count=repeat_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
return new_prompt, new_token_ids
class ImagePlugin(MultiModalPlugin):
......@@ -27,7 +110,7 @@ class ImagePlugin(MultiModalPlugin):
trust_remote_code=model_config.trust_remote_code)
def _default_input_mapper(self, ctx: InputContext,
data: object) -> Dict[str, torch.Tensor]:
data: object) -> MultiModalInputs:
model_config = ctx.model_config
if isinstance(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
......@@ -35,10 +118,15 @@ class ImagePlugin(MultiModalPlugin):
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
try:
return image_processor.preprocess(data, return_tensors="pt") \
.to(model_config.dtype).data
batch_data = image_processor \
.preprocess(data, return_tensors="pt") \
.data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
raise TypeError(f"Invalid type for 'image': {type(data)}")
return MultiModalInputs(batch_data)
elif isinstance(data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
raise TypeError(f"Invalid image type: {type(data)}")
import functools
from typing import Optional, Sequence, Type, TypeVar
from typing import Dict, Optional, Sequence
from torch import nn
import torch
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin)
from .image import ImagePlugin
logger = init_logger(__name__)
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalRegistry:
"""
......@@ -61,7 +60,7 @@ class MultiModalRegistry:
return self.register_input_mapper("image", mapper)
def _process_input(self, key: str, value: object,
model_config: ModelConfig):
model_config: ModelConfig) -> MultiModalInputs:
plugin = self._plugins.get(key)
if plugin:
return plugin.map_input(model_config, value)
......@@ -93,16 +92,28 @@ class MultiModalRegistry:
"""
return self.register_input_mapper("image", mapper)
def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
def map_input(self, model_config: ModelConfig,
data: MultiModalDataDict) -> MultiModalInputs:
"""
Apply an input mapper to the data passed to the model.
See :meth:`MultiModalPlugin.map_input` for more details.
"""
result_list = [
self._process_input(k, v, model_config) for k, v in data.items()
]
return {k: v for d in result_list for k, v in d.items()}
merged_dict: Dict[str, torch.Tensor] = {}
for data_key, data_value in data.items():
input_dict = self._process_input(data_key, data_value,
model_config)
for input_key, input_tensor in input_dict.items():
if input_key in merged_dict:
raise ValueError(f"The input mappers (keys={set(data)}) "
f"resulted in a conflicting keyword "
f"argument to `forward()`: {input_key}")
merged_dict[input_key] = input_tensor
return MultiModalInputs(merged_dict)
def create_input_mapper(self, model_config: ModelConfig):
"""
......
......@@ -4,11 +4,56 @@ from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
import requests
from PIL import Image
from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict
from vllm.version import __version__ as VLLM_VERSION
def _validate_remote_url(url: str, *, name: str):
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise ValueError(f"Invalid '{name}': A valid '{name}' "
"must have scheme 'http' or 'https'.")
def _get_request_headers():
return {"User-Agent": f"vLLM/{VLLM_VERSION}"}
def _load_image_from_bytes(b: bytes):
image = Image.open(BytesIO(b))
image.load()
return image
def _load_image_from_data_url(image_url: str):
# Only split once and assume the second part is the base64 encoded image
_, image_base64 = image_url.split(",", 1)
return load_image_from_base64(image_base64)
def fetch_image(image_url: str) -> Image.Image:
"""Load PIL image from a url or base64 encoded openai GPT4V format"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
headers = _get_request_headers()
with requests.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = response.content
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
return image
class ImageFetchAiohttp:
......@@ -29,34 +74,31 @@ class ImageFetchAiohttp:
"""Load PIL image from a url or base64 encoded openai GPT4V format"""
if image_url.startswith('http'):
parsed_url = urlparse(image_url)
if parsed_url.scheme not in ["http", "https"]:
raise ValueError("Invalid 'image_url': A valid 'image_url' "
"must have scheme 'http' or 'https'.")
# Avoid circular import
from vllm import __version__ as VLLM_VERSION
_validate_remote_url(image_url, name="image_url")
client = cls.get_aiohttp_client()
headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"}
headers = _get_request_headers()
async with client.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = await response.read()
image = Image.open(BytesIO(image_raw))
image = _load_image_from_bytes(image_raw)
# Only split once and assume the second part is the base64 encoded image
elif image_url.startswith('data:image'):
image = load_image_from_base64(image_url.split(',', 1)[1])
image = _load_image_from_data_url(image_url)
else:
raise ValueError(
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
image.load()
return image
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
return {"image": image}
def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
"""Encode a pillow image to base64 format."""
......@@ -69,26 +111,11 @@ def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
"""Load image from base64 format."""
return Image.open(BytesIO(base64.b64decode(image)))
return _load_image_from_bytes(base64.b64decode(image))
# TODO(ywang96): move this to a model registry for preprocessing vision
# language prompts based on the model type.
def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
config: ModelConfig) -> str:
"""Combine image and text prompts for vision language model depending on
the model architecture."""
if config.hf_config.model_type in ("llava", "llava_next"):
full_prompt = f"{image_prompt}\n{text_prompt}"
elif config.hf_config.model_type == 'phi3_v':
full_prompt = f"{image_prompt}<s>\n{text_prompt}"
else:
raise ValueError(
f"Unsupported model type: {config.hf_config.model_type}")
return full_prompt
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
return {"image": image}
def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
"""Rescale the dimensions of an image by a constant factor."""
new_width = int(image.width * size_factor)
new_height = int(image.height * size_factor)
return image.resize((new_width, new_height))
......@@ -457,7 +457,7 @@ class SequenceGroup:
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
......
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from vllm.logger import init_logger
logger = init_logger(__name__)
from typing import cast
def get_image_processor(
......@@ -11,10 +6,15 @@ def get_image_processor(
*args,
trust_remote_code: bool = False,
**kwargs,
) -> BaseImageProcessor:
):
"""Gets an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
try:
processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
processor = AutoImageProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
......@@ -34,4 +34,4 @@ def get_image_processor(
else:
raise e
return processor
return cast(BaseImageProcessor, processor)
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
import torch
from torch import nn
......@@ -12,7 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
......@@ -40,7 +41,7 @@ class CPUModelInput(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -132,15 +133,14 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
str, torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
multi_modal_inputs_list: List[MultiModalInputs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -162,10 +162,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
input_positions.extend(list(range(computed_len, seq_len)))
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
multi_modal_inputs_list.append(mm_kwargs)
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
......@@ -189,11 +188,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
multi_modal_kwargs = {
k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items()
}
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
......@@ -217,6 +211,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
......@@ -367,10 +365,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
}
if (self.vision_language_config
and model_input.multi_modal_kwargs is not None):
execute_model_kwargs.update(model_input.multi_modal_kwargs)
hidden_states = model_executable(**execute_model_kwargs)
......
......@@ -92,10 +92,9 @@ class EmbeddingModelRunner(
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
}
if self.vision_language_config:
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
execute_model_kwargs.update({"image_input": multi_modal_kwargs})
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
......
......@@ -3,8 +3,8 @@ import gc
import time
import warnings
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
import numpy as np
import torch
......@@ -37,7 +37,8 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
......@@ -83,7 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
......@@ -356,8 +357,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
context_lens: List[int] = []
query_lens: List[int] = []
block_tables: List[List[int]] = []
multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
multi_modal_inputs_list: List[MultiModalInputs] = []
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
decode_only = True
num_prefills = 0
......@@ -528,8 +528,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
multi_modal_inputs_list.append(mm_kwargs)
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
......@@ -746,10 +745,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
else:
lora_mapping = None
multi_modal_kwargs = {
k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items()
}
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
request_ids_to_seq_ids = {
seq_group_metadata.request_id:
list(seq_group_metadata.seq_data.keys())
......@@ -821,7 +818,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
assert len(seq_data.prompt_token_ids) == seq_len
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata(
request_id=str(group_id),
......
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Union)
import torch
from torch import nn
......@@ -9,6 +10,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
......@@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -65,6 +69,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # initialize after load_model.
......@@ -76,13 +84,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -102,6 +112,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
assert len(block_table) == 1
input_block_ids.append(block_table[0])
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
......@@ -118,7 +134,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids, seq_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
......@@ -184,8 +204,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
seq_lens) = self._prepare_prompt(seq_group_metadata_list)
(input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
......@@ -203,7 +224,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata)
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)
@torch.inference_mode()
def execute_model(
......@@ -221,6 +243,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**(model_input.multi_modal_kwargs or {}),
)
# Compute the logits.
......
from typing import List, NamedTuple, Optional, Tuple
from typing import List, Mapping, NamedTuple, Optional, Tuple
import openvino as ov
import torch
......@@ -12,6 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
logger = init_logger(__name__)
......@@ -23,7 +25,7 @@ class ModelInput(NamedTuple):
attn_metadata: Optional[OpenVINOAttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
multi_modal_input: Optional[torch.Tensor]
multi_modal_kwargs: Mapping[str, BatchedTensors]
@classmethod
def empty(cls, device):
......@@ -32,7 +34,7 @@ class ModelInput(NamedTuple):
attn_metadata=None,
seq_lens=[],
query_lens=[],
multi_modal_input=None)
multi_modal_kwargs={})
class OpenVINOModelRunner:
......@@ -78,6 +80,10 @@ class OpenVINOModelRunner:
self.block_size,
)
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
......@@ -108,6 +114,8 @@ class OpenVINOModelRunner:
seq_lens: List[int] = []
past_lens: List[int] = []
query_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
subsequence_begins: List[int] = []
block_indices: List[int] = []
block_indices_begins: List[int] = []
......@@ -160,6 +168,11 @@ class OpenVINOModelRunner:
and self.sliding_window is None
and is_prompt)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
block_table = seq_group_metadata.block_tables[seq_id]
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
......@@ -251,22 +264,24 @@ class OpenVINOModelRunner:
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return ModelInput(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
query_lens,
None,
multi_modal_kwargs=multi_modal_kwargs,
)
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
SamplingMetadata, Optional[torch.Tensor], ]:
multi_modal_input = None
SamplingMetadata, Mapping[str, BatchedTensors]]:
# Prepare input tensors.
(
input_tokens,
......@@ -274,7 +289,7 @@ class OpenVINOModelRunner:
attn_metadata,
seq_lens,
query_lens,
multi_modal_input,
multi_modal_kwargs,
) = self._prepare_model_input(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
......@@ -290,7 +305,7 @@ class OpenVINOModelRunner:
input_positions,
attn_metadata,
sampling_metadata,
multi_modal_input,
multi_modal_kwargs,
)
@torch.inference_mode()
......@@ -304,7 +319,7 @@ class OpenVINOModelRunner:
input_positions,
attn_metadata,
sampling_metadata,
multi_modal_input,
multi_modal_kwargs,
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
......@@ -313,9 +328,8 @@ class OpenVINOModelRunner:
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
**(multi_modal_kwargs or {}),
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
......
import time
from typing import List, Optional, Tuple
from typing import List, Mapping, Optional, Tuple
import numpy as np
import torch
......@@ -12,6 +12,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
......@@ -66,6 +68,10 @@ class TPUModelRunner:
False,
)
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
def load_model(self) -> None:
self.device = self.device_config.device
......@@ -193,12 +199,14 @@ class TPUModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
):
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
prompt_lens: List[int] = []
slot_mapping: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -224,6 +232,11 @@ class TPUModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
num_prefill_tokens = sum(prompt_lens)
......@@ -261,17 +274,24 @@ class TPUModelRunner:
block_tables=None,
context_lens=None,
)
return input_tokens, input_positions, attn_metadata, prompt_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
):
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
......@@ -297,6 +317,11 @@ class TPUModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
......@@ -330,7 +355,12 @@ class TPUModelRunner:
block_tables=block_tables,
context_lens=context_lens,
)
return input_tokens, input_positions, attn_metadata, input_lens
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, input_lens,
multi_modal_kwargs)
def _prepare_sample(
self,
......@@ -483,6 +513,7 @@ class ModelWrapper(nn.Module):
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attn_metadata: AttentionMetadata,
input_lens: torch.Tensor,
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
......@@ -496,6 +527,8 @@ class ModelWrapper(nn.Module):
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
"""
......@@ -540,6 +573,7 @@ class ModelWrapper(nn.Module):
position_ids,
kv_caches,
attn_metadata,
**(multi_modal_kwargs or {}),
)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
......
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
import torch
import torch.nn as nn
......@@ -9,10 +10,13 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
......@@ -44,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -116,6 +120,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.block_size,
)
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
......@@ -156,12 +164,26 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
......@@ -194,7 +216,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPU:
multi_modal_input = None
multi_modal_kwargs = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
......@@ -202,7 +224,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
......@@ -223,6 +245,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
"input_positions": input_positions,
"selected_token_indices":
sampling_metadata.selected_token_indices,
"multi_modal_kwargs": multi_modal_kwargs,
}
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
......@@ -232,6 +255,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
input_positions = metadata_dict.pop("input_positions")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
......@@ -244,7 +268,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
multi_modal_input=multi_modal_input)
multi_modal_kwargs=multi_modal_kwargs)
def _prepare_decode(
self,
......@@ -350,10 +374,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
}
if self.vision_language_config:
execute_model_kwargs.update(
{"image_input": model_input.multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
......@@ -376,13 +398,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Optional[torch.Tensor]]:
Mapping[str, BatchedTensors]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -403,9 +425,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
......@@ -435,15 +458,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
......@@ -475,5 +489,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
num_decode_tokens=0,
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input)
multi_modal_kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment