Unverified Commit e6e42e4b authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Core][VLM] Support image embeddings as input (#6613)

parent ec2affa8
...@@ -49,6 +49,17 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI ...@@ -49,6 +49,17 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI
"multi_modal_data": {"image": image}, "multi_modal_data": {"image": image},
}) })
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Inference with image embeddings as input
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": image_embeds},
})
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
print(generated_text) print(generated_text)
......
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom":
"USER: <image>\nWhat is the season?\nASSISTANT:",
})
models = [
"llava-hf/llava-1.5-7b-hf",
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# vLLM to load from image embeddings
vllm_images = [asset.image_embeds for asset in image_assets]
# transformers to load from PIL images
hf_images = [asset.pil_image for asset in image_assets]
vllm_inputs_per_image = [(
[prompt for _ in size_factors],
[image for _ in size_factors],
) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)]
hf_inputs_per_image = [(
[prompt for _ in size_factors],
[image for _ in size_factors],
) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in vllm_inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in hf_inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Literal
import torch
from PIL import Image from PIL import Image
from vllm.assets.base import get_vllm_public_assets from vllm.assets.base import get_vllm_public_assets
...@@ -18,3 +19,12 @@ class ImageAsset: ...@@ -18,3 +19,12 @@ class ImageAsset:
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR) s3_prefix=VLM_IMAGES_DIR)
return Image.open(image_path) return Image.open(image_path)
@property
def image_embeds(self) -> torch.Tensor:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path)
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -28,6 +28,29 @@ _KEYS_TO_MODIFY_MAPPING = { ...@@ -28,6 +28,29 @@ _KEYS_TO_MODIFY_MAPPING = {
"language_model.model": "language_model", "language_model.model": "language_model",
} }
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module): class Blip2QFormerMultiHeadAttention(nn.Module):
...@@ -375,20 +398,6 @@ class Blip2QFormerModel(nn.Module): ...@@ -375,20 +398,6 @@ class Blip2QFormerModel(nn.Module):
return sequence_output return sequence_output
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
Blip2ImageInputs = Blip2ImagePixelInputs
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int: def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens return hf_config.num_query_tokens
...@@ -506,18 +515,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision): ...@@ -506,18 +515,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]: self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if not isinstance(pixel_values, torch.Tensor): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, torch.Tensor):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Blip2ImagePixelInputs( return Blip2ImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(self, vision_model: BlipVisionModel, def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor: pixel_values: torch.Tensor) -> torch.Tensor:
...@@ -538,6 +560,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision): ...@@ -538,6 +560,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self, def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor: image_input: Blip2ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None assert self.vision_model is not None
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
......
...@@ -88,7 +88,13 @@ def input_processor_for_clip( ...@@ -88,7 +88,13 @@ def input_processor_for_clip(
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config) image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
......
...@@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision): ...@@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None) image_patches = kwargs.pop("image_patches", None)
if isinstance(image_patches, torch.Tensor): if isinstance(image_patches, torch.Tensor):
...@@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision): ...@@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
data=image_patches) data=image_patches)
return None return None
def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision): ...@@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens( vision_embeddings = self._process_image_input(image_input)
image_input["data"])
inputs_embeds = self.language_model.model.embed_tokens(input_ids) inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings, vision_embeddings,
......
...@@ -50,6 +50,19 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -50,6 +50,19 @@ class InternVLImagePixelInputs(TypedDict):
""" """
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B # copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size): def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
...@@ -193,8 +206,10 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -193,8 +206,10 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
# add thumbnail image if num_blocks > 1 # add thumbnail image if num_blocks > 1
if hf_config.use_thumbnail and num_blocks > 1: if hf_config.use_thumbnail and num_blocks > 1:
num_blocks += 1 num_blocks += 1
image_feature_size = num_blocks * num_patches
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") image_feature_size = image_data.shape[0]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
...@@ -205,7 +220,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -205,7 +220,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None: if prompt is None:
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
new_prompt = prompt.replace('<image>', image_prompt, 1) new_prompt = prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt) new_prompt_token_ids = tokenizer.encode(new_prompt)
...@@ -378,23 +393,49 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -378,23 +393,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
return data return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImagePixelInputs]: self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None) image_token_id = kwargs.pop("image_token_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
self.img_context_token_id = image_token_id[0] self.img_context_token_id = image_token_id[0]
if not isinstance(pixel_values, (torch.Tensor, list)): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, (torch.Tensor, list)):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])
return InternVLImagePixelInputs( return image_embeds
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
def forward( def forward(
self, self,
...@@ -409,9 +450,9 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -409,9 +450,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
if image_input is not None: if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
vit_embeds = self.extract_feature(image_input["data"]) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vit_embeds, vision_embeddings,
self.img_context_token_id) self.img_context_token_id)
input_ids = None input_ids = None
else: else:
......
...@@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model, ...@@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings) merge_vision_embeddings)
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
# TODO(xwjiang): Run benchmark and decide if TP. # TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module): class LlavaMultiModalProjector(nn.Module):
...@@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module): ...@@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states return hidden_states
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs = LlavaImagePixelInputs
def get_max_llava_image_tokens(ctx: InputContext): def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]: self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if not isinstance(pixel_values, torch.Tensor): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, torch.Tensor):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor: strategy: str) -> torch.Tensor:
...@@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self, def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor: image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
......
...@@ -60,7 +60,17 @@ class LlavaNextImagePixelInputs(TypedDict): ...@@ -60,7 +60,17 @@ class LlavaNextImagePixelInputs(TypedDict):
""" """
LlavaNextImageInputs = LlavaNextImagePixelInputs class LlavaNextImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
...@@ -208,7 +218,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -208,7 +218,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_width=width, input_width=width,
) )
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") image_feature_size = image_data.shape[0]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
...@@ -320,26 +330,40 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -320,26 +330,40 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]: self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None) image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if not isinstance(pixel_values, (torch.Tensor, list)): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, (torch.Tensor, list)):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor): if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. " raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}") f"Got type: {type(image_sizes)}")
return LlavaNextImagePixelInputs( return LlavaNextImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes), image_sizes=self._validate_image_sizes(image_sizes),
) )
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor: strategy: str) -> torch.Tensor:
...@@ -466,6 +490,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -466,6 +490,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self, self,
image_input: LlavaNextImageInputs, image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, List[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input) patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes") image_sizes = image_input.get("image_sizes")
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = { ...@@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = {
} }
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class PaliGemmaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs]
def get_max_paligemma_image_tokens(ctx: InputContext): def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig) hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -107,15 +126,6 @@ class PaliGemmaMultiModalProjector(nn.Module): ...@@ -107,15 +126,6 @@ class PaliGemmaMultiModalProjector(nn.Module):
return hidden_states return hidden_states
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
PaliGemmaImageInputs = PaliGemmaImagePixelInputs
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
...@@ -163,18 +173,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -163,18 +173,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
if not isinstance(pixel_values, torch.Tensor): if pixel_values is not None:
raise ValueError("Incorrect type of pixel values. " if not isinstance(pixel_values, torch.Tensor):
f"Got type: {type(pixel_values)}") raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return PaliGemmaImagePixelInputs( return PaliGemmaImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features( def _image_pixels_to_features(
self, self,
...@@ -187,27 +209,21 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -187,27 +209,21 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return image_features return image_features
def _process_image_pixels( def _process_image_input(
self, self,
inputs: PaliGemmaImagePixelInputs, image_input: PaliGemmaImageInputs,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"] if image_input["type"] == "image_embeds":
return image_input["data"]
return self._image_pixels_to_features( assert self.vision_tower is not None
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(
self.vision_tower, self.vision_tower,
pixel_values, pixel_values,
) )
def _process_image_input(
self,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input, )
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
def forward(self, def forward(self,
......
...@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, ...@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768) projection_dim=768)
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
class Phi3VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
class Phi3ImageEmbeddingBase(nn.Module): class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
...@@ -257,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -257,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline return image_features_hd_newline
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit) target_height = int(np.ceil(height / padding_unit) * padding_unit)
...@@ -390,7 +402,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -390,7 +402,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width=w, input_width=w,
input_height=h) input_height=h)
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") image_feature_size = image_data.shape[0]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
...@@ -494,25 +506,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -494,25 +506,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return data return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]: self, **kwargs: object) -> Optional[Phi3VImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None) image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None: if pixel_values is None:
return None return None
if not isinstance(pixel_values, (torch.Tensor, list)): if pixel_values is None and image_embeds is None:
raise ValueError("Incorrect type of pixel values. " return None
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Phi3VImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Phi3VImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
if not isinstance(image_sizes, torch.Tensor): assert self.vision_embed_tokens is not None
raise ValueError("Incorrect type of image sizes. " image_embeds = self.vision_embed_tokens(image_input["data"],
f"Got type: {type(image_sizes)}") image_input["image_sizes"])
return Phi3VImagePixelInputs( return image_embeds
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -524,8 +566,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -524,8 +566,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self.vision_embed_tokens( vision_embeddings = self._process_image_input(image_input)
image_input["data"], image_input["image_sizes"])
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings, vision_embeddings,
......
...@@ -97,7 +97,13 @@ def input_processor_for_siglip( ...@@ -97,7 +97,13 @@ def input_processor_for_siglip(
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None: if image_feature_size_override is None:
image_feature_size = get_siglip_image_feature_size(hf_config) image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_siglip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
......
...@@ -115,6 +115,7 @@ class ImagePlugin(MultiModalPlugin): ...@@ -115,6 +115,7 @@ class ImagePlugin(MultiModalPlugin):
data: object) -> MultiModalInputs: data: object) -> MultiModalInputs:
model_config = ctx.model_config model_config = ctx.model_config
# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image): if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config) image_processor = self._get_hf_image_processor(model_config)
if image_processor is None: if image_processor is None:
...@@ -129,8 +130,10 @@ class ImagePlugin(MultiModalPlugin): ...@@ -129,8 +130,10 @@ class ImagePlugin(MultiModalPlugin):
raise raise
return MultiModalInputs(batch_data) return MultiModalInputs(batch_data)
# Image embedding
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet") return MultiModalInputs({"image_embeds": data})
raise TypeError(f"Invalid image type: {type(data)}") raise TypeError(f"Invalid image type: {type(data)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment