Unverified Commit 74704d45 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Use merge_by_field_config for MM models (O-P) (#26776)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d2f816d6
...@@ -56,7 +56,6 @@ from vllm.multimodal.processing import ( ...@@ -56,7 +56,6 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
...@@ -70,7 +69,6 @@ from .utils import ( ...@@ -70,7 +69,6 @@ from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
...@@ -564,6 +562,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -564,6 +562,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
dummy_inputs=Phi3VDummyInputsBuilder, dummy_inputs=Phi3VDummyInputsBuilder,
) )
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens", "model.vision_embed_tokens.wte": "embed_tokens",
...@@ -631,8 +631,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -631,8 +631,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if pixel_values is not None: if pixel_values is not None:
return Phi3VImagePixelInputs( return Phi3VImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=flatten_bn(pixel_values), pixel_values=pixel_values,
image_sizes=flatten_bn(image_sizes, concat=True), image_sizes=image_sizes,
resolve_bindings={ resolve_bindings={
"h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
"w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
...@@ -642,7 +642,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -642,7 +642,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if image_embeds is not None: if image_embeds is not None:
return Phi3VImageEmbeddingInputs( return Phi3VImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds), data=image_embeds,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -652,19 +652,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -652,19 +652,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
image_input: Phi3VImageInputs, image_input: Phi3VImageInputs,
) -> torch.Tensor: ) -> torch.Tensor:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
image_data = image_input["data"] return image_input["data"]
if is_list_of(image_data, torch.Tensor):
# it's already a list of tensors
return image_data
if len(image_data.shape) == 3:
# 3D tensor
return list(torch.unbind(image_data, dim=0))
raise ValueError(
"We expect batched 2D tensors; "
"this can be either a list of 2D tensors or a single 3D tensor."
)
assert self.vision_embed_tokens is not None assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens( image_embeds = self.vision_embed_tokens(
image_input["pixel_values"], image_input["image_sizes"] image_input["pixel_values"], image_input["image_sizes"]
) )
......
...@@ -64,7 +64,6 @@ from vllm.multimodal.processing import ( ...@@ -64,7 +64,6 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
...@@ -72,7 +71,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal ...@@ -72,7 +71,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
...@@ -672,7 +670,7 @@ class Phi4MMImagePixelInputs(TensorSchema): ...@@ -672,7 +670,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: Annotated[ pixel_values: Annotated[
torch.Tensor | list[torch.Tensor], torch.Tensor | list[torch.Tensor],
TensorShape( TensorShape(
"bn", "p", 3, "h", "w", dynamic_dims={"p"} "bn", "p", 3, "h", "w", dynamic_dims={"p"}
...@@ -721,7 +719,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema): ...@@ -721,7 +719,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
type: Literal["audio_features"] type: Literal["audio_features"]
data: Annotated[ audio_features: Annotated[
torch.Tensor | list[torch.Tensor], torch.Tensor | list[torch.Tensor],
TensorShape("bn", "t", 80, dynamic_dims={"t"}), TensorShape("bn", "t", 80, dynamic_dims={"t"}),
] ]
...@@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
Implements the Phi-4-multimodal-instruct model in vLLM. Implements the Phi-4-multimodal-instruct model in vLLM.
""" """
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"qkv_proj", "qkv_proj",
...@@ -1273,7 +1273,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1273,7 +1273,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if audio_features is not None: if audio_features is not None:
return Phi4MMAudioFeatureInputs( return Phi4MMAudioFeatureInputs(
type="audio_features", data=flatten_bn(audio_features) type="audio_features",
audio_features=audio_features,
) )
if audio_embeds is not None: if audio_embeds is not None:
...@@ -1298,7 +1299,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1298,7 +1299,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if audio_input["type"] == "audio_embeds": if audio_input["type"] == "audio_embeds":
return audio_input["data"] return audio_input["data"]
audio_features = audio_input["data"] audio_features = audio_input["audio_features"]
# (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example) # (e.g. multiple audios in the same example)
...@@ -1315,8 +1316,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1315,8 +1316,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Phi4MMImagePixelInputs | None: ) -> Phi4MMImagePixelInputs | None:
image_pixel_values: NestedTensors = kwargs.get("image_pixel_values") pixel_values = kwargs.get("image_pixel_values")
if image_pixel_values is None: if pixel_values is None:
return None return None
image_sizes = kwargs.get("image_sizes") image_sizes = kwargs.get("image_sizes")
...@@ -1328,52 +1329,9 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1328,52 +1329,9 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
and num_img_tokens is not None and num_img_tokens is not None
), "Missing image inputs" ), "Missing image inputs"
if is_list_of(image_pixel_values, torch.Tensor):
assert all(p.dim() == 5 for p in image_pixel_values), (
"Incorrect image inputs"
)
# list len is batch_size.
# each tensor has dimension: num_img_per_example, num_hd_patches,
# channels, height, width.
# need to pad along num_hd_patches.
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
elif isinstance(image_pixel_values, torch.Tensor):
# dimension: batch_size, num_img_per_example, num_hd_patches,
# channels, height, width.
# we flatten first 2 dims to make it a single large batch for
# SigLIP Encoder.
assert image_pixel_values.dim() == 6, "Incorrect image inputs"
image_pixel_values = image_pixel_values.flatten(0, 1)
else:
raise ValueError("Incorrect image_pixel_values inputs")
if isinstance(image_attention_mask, list):
image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
elif isinstance(image_attention_mask, torch.Tensor):
image_attention_mask = image_attention_mask.flatten(0, 1)
else:
raise ValueError("Incorrect image_attention_mask inputs")
if isinstance(image_sizes, list):
image_sizes = torch.cat(image_sizes, dim=0)
elif isinstance(image_sizes, torch.Tensor):
image_sizes = image_sizes.flatten(0, 1)
else:
raise ValueError("Incorrect image_sizes inputs")
if isinstance(num_img_tokens, list):
num_img_tokens = [
n for num_tensor in num_img_tokens for n in num_tensor.tolist()
]
elif isinstance(num_img_tokens, torch.Tensor):
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
else:
raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs( return Phi4MMImagePixelInputs(
type="pixel_values", type="pixel_values",
data=image_pixel_values, pixel_values=pixel_values,
image_sizes=image_sizes, image_sizes=image_sizes,
image_attention_mask=image_attention_mask, image_attention_mask=image_attention_mask,
num_img_tokens=num_img_tokens, num_img_tokens=num_img_tokens,
...@@ -1405,7 +1363,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1405,7 +1363,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
image_embeds = image_input["image_embeds"].type(self.visual.dtype) image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else: else:
dtype = next(self.image_embed.parameters()).dtype dtype = next(self.image_embed.parameters()).dtype
pixel_values = image_input["data"].to(dtype) pixel_values = image_input["pixel_values"].to(dtype)
image_sizes = image_input["image_sizes"] image_sizes = image_input["image_sizes"]
image_attention_mask = image_input["image_attention_mask"] image_attention_mask = image_input["image_attention_mask"]
image_embeds = self.image_embed( image_embeds = self.image_embed(
......
...@@ -50,13 +50,12 @@ from vllm.multimodal.processing import ( ...@@ -50,13 +50,12 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding from .phi4mm_audio import AudioEmbedding
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
# <|endoftext10|> (see vocab.json in hf model) # <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
...@@ -467,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema): ...@@ -467,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: Annotated[ pixel_values: Annotated[
torch.Tensor | list[torch.Tensor], torch.Tensor | list[torch.Tensor],
TensorShape( TensorShape(
"bn", "p", 3, "h", "w", dynamic_dims={"p"} "bn", "p", 3, "h", "w", dynamic_dims={"p"}
...@@ -499,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema): ...@@ -499,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
type: Literal["audio_features"] type: Literal["audio_features"]
data: Annotated[ audio_features: Annotated[
torch.Tensor | list[torch.Tensor], torch.Tensor | list[torch.Tensor],
TensorShape("bn", "t", 80, dynamic_dims={"t"}), TensorShape("bn", "t", 80, dynamic_dims={"t"}),
] ]
...@@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
Implements the Phi-4-multimodal-instruct model in vLLM. Implements the Phi-4-multimodal-instruct model in vLLM.
""" """
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"qkv_proj", "qkv_proj",
...@@ -1094,7 +1095,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1094,7 +1095,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if audio_features is not None: if audio_features is not None:
return Phi4MMAudioFeatureInputs( return Phi4MMAudioFeatureInputs(
type="audio_features", data=flatten_bn(audio_features) type="audio_features",
audio_features=audio_features,
) )
if audio_embeds is not None: if audio_embeds is not None:
...@@ -1119,7 +1121,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1119,7 +1121,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if audio_input["type"] == "audio_embeds": if audio_input["type"] == "audio_embeds":
return audio_input["data"] return audio_input["data"]
audio_features = audio_input["data"] audio_features = audio_input["audio_features"]
# (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example) # (e.g. multiple audios in the same example)
...@@ -1136,8 +1138,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1136,8 +1138,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Phi4MMImagePixelInputs | None: ) -> Phi4MMImagePixelInputs | None:
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") pixel_values = kwargs.get("input_image_embeds")
if input_image_embeds is None: if pixel_values is None:
return None return None
image_sizes = kwargs.get("image_sizes") image_sizes = kwargs.get("image_sizes")
...@@ -1149,52 +1151,9 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1149,52 +1151,9 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
and num_img_tokens is not None and num_img_tokens is not None
), "Missing image inputs" ), "Missing image inputs"
if is_list_of(input_image_embeds, torch.Tensor):
assert all(p.dim() == 5 for p in input_image_embeds), (
"Incorrect image inputs"
)
# list len is batch_size.
# each tensor has dimension: num_img_per_example, num_hd_patches,
# channels, height, width.
# need to pad along num_hd_patches.
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
input_image_embeds = cat_with_pad(input_image_embeds, dim=0)
elif isinstance(input_image_embeds, torch.Tensor):
# dimension: batch_size, num_img_per_example, num_hd_patches,
# channels, height, width.
# we flatten first 2 dims to make it a single large batch for
# SigLIP Encoder.
assert input_image_embeds.dim() == 6, "Incorrect image inputs"
input_image_embeds = input_image_embeds.flatten(0, 1)
else:
raise ValueError("Incorrect input_image_embeds inputs")
if isinstance(image_attention_mask, list):
image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
elif isinstance(image_attention_mask, torch.Tensor):
image_attention_mask = image_attention_mask.flatten(0, 1)
else:
raise ValueError("Incorrect image_attention_mask inputs")
if isinstance(image_sizes, list):
image_sizes = torch.cat(image_sizes, dim=0)
elif isinstance(image_sizes, torch.Tensor):
image_sizes = image_sizes.flatten(0, 1)
else:
raise ValueError("Incorrect image_sizes inputs")
if isinstance(num_img_tokens, list):
num_img_tokens = [
n for num_tensor in num_img_tokens for n in num_tensor.tolist()
]
elif isinstance(num_img_tokens, torch.Tensor):
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
else:
raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs( return Phi4MMImagePixelInputs(
type="pixel_values", type="pixel_values",
data=input_image_embeds, pixel_values=pixel_values,
image_sizes=image_sizes, image_sizes=image_sizes,
image_attention_mask=image_attention_mask, image_attention_mask=image_attention_mask,
num_img_tokens=num_img_tokens, num_img_tokens=num_img_tokens,
...@@ -1223,7 +1182,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1223,7 +1182,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
self, image_input: Phi4MMImagePixelInputs self, image_input: Phi4MMImagePixelInputs
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
dtype = next(self.vision_encoder.parameters()).dtype dtype = next(self.vision_encoder.parameters()).dtype
pixel_values = image_input["data"].to(dtype) pixel_values = image_input["pixel_values"].to(dtype)
image_sizes = image_input["image_sizes"] image_sizes = image_input["image_sizes"]
image_attention_mask = image_input["image_attention_mask"] image_attention_mask = image_input["image_attention_mask"]
image_embeds = self.vision_encoder( image_embeds = self.vision_encoder(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment