Unverified Commit 04de9057 authored by whyiug's avatar whyiug Committed by GitHub
Browse files

[Model] support input image embedding for minicpmv (#9237)

parent 07c11cf4
...@@ -378,7 +378,7 @@ Text Generation ...@@ -378,7 +378,7 @@ Text Generation
- ✅︎ - ✅︎
* - :code:`MiniCPMV` * - :code:`MiniCPMV`
- MiniCPM-V - MiniCPM-V
- Image\ :sup:`+` - Image\ :sup:`E+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
- ✅︎ - ✅︎
- ✅︎ - ✅︎
......
...@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT ...@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
print(generated_text) print(generated_text)
# Inference with image embeddings as input with additional parameters # Inference with image embeddings as input with additional parameters
# Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding. # Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM) mm_data = {}
image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
}
# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
mm_data['image'] = { mm_data['image'] = {
"image_embeds": image_embeds, "image_embeds": image_embeds,
"image_grid_thw": image_grid_thw, "image_size_list": [image.size] # list of image sizes
} }
outputs = llm.generate({ outputs = llm.generate({
"prompt": prompt, "prompt": prompt,
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
import math import math
import re import re
from functools import partial from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
TypedDict) Tuple, TypedDict, Union)
import torch import torch
import torch.types import torch.types
...@@ -65,10 +65,12 @@ _KEYS_TO_MODIFY_MAPPING = { ...@@ -65,10 +65,12 @@ _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
} }
RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImageInput(TypedDict):
class MiniCPMVRawImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds.""" """Input mapper input with auxiliary data for computing image bounds."""
image: Image.Image image: RawImageType
# Image bounds token ids in 0-dim scaler tensor. # Image bounds token ids in 0-dim scaler tensor.
im_start_id: torch.Tensor im_start_id: torch.Tensor
...@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict): ...@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImagePixelInputs(TypedDict):
pixel_values: List[torch.Tensor] type: Literal["pixel_values"]
data: List[torch.Tensor]
""" """
Shape: `(batch_size * num_images, num_channels, height, width)` Shape: `(batch_size * num_images, num_channels, height, width)`
...@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict): ...@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
""" """
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
...@@ -194,22 +218,22 @@ class Resampler2_5(BaseResampler): ...@@ -194,22 +218,22 @@ class Resampler2_5(BaseResampler):
def _build_image_input(ctx: InputContext, def _build_image_input(ctx: InputContext,
image: Image.Image) -> MiniCPMVImageInput: image: RawImageType) -> MiniCPMVRawImageInput:
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer, ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code) trust_remote_code=ctx.model_config.trust_remote_code)
if hasattr(tokenizer, "slice_start_id"): if hasattr(tokenizer, "slice_start_id"):
return MiniCPMVImageInput( return MiniCPMVRawImageInput(
image=image, image=image,
im_start_id=torch.tensor(tokenizer.im_start_id), im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id), im_end_id=torch.tensor(tokenizer.im_end_id),
slice_start_id=torch.tensor(tokenizer.slice_start_id), slice_start_id=torch.tensor(tokenizer.slice_start_id),
slice_end_id=torch.tensor(tokenizer.slice_end_id)) slice_end_id=torch.tensor(tokenizer.slice_end_id))
else: else:
return MiniCPMVImageInput(image=image, return MiniCPMVRawImageInput(
im_start_id=torch.tensor( image=image,
tokenizer.im_start_id), im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id)) im_end_id=torch.tensor(tokenizer.im_end_id))
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
...@@ -280,20 +304,25 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -280,20 +304,25 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
pattern = "(<image>./</image>)" pattern = "(<image>./</image>)"
images = multi_modal_data["image"] images = multi_modal_data["image"]
if isinstance(images, Image.Image):
images = [images]
image_tags = re.findall(pattern, prompt) image_tags = re.findall(pattern, prompt)
if len(image_tags) == 0: if len(image_tags) == 0:
new_token_ids = token_ids new_token_ids = token_ids
new_prompt = prompt new_prompt = prompt
else: else:
if isinstance(images, dict):
image_size_list = images.get("image_size_list")
images = [images.get("image_embeds")]
else:
if isinstance(images, Image.Image):
images = [images]
image_size_list = [image.size for image in images]
text_chunks = prompt.split(pattern) text_chunks = prompt.split(pattern)
new_prompt_chunks: List[str] = [] new_prompt_chunks: List[str] = []
for i in range(len(images)): for i in range(len(image_size_list)):
new_prompt_chunks += [ new_prompt_chunks += [
text_chunks[i], text_chunks[i],
get_placeholder(images[i].size, i) get_placeholder(image_size_list[i], i)
] ]
new_prompt_chunks.append(text_chunks[-1]) new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks) new_prompt = "".join(new_prompt_chunks)
...@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): ...@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
if not isinstance(data, list): if not isinstance(data, list):
raise ValueError( raise ValueError(
"Image input must be list of MiniCPMVImageInput, got (%s)", data) "Image input must be list of MiniCPMVImageInput, got (%s)", data)
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \ if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
.data batch_data = {
"image_embeds": data[0]['image'],
}
else:
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data
if len(data) > 0: if len(data) > 0:
batch_data["im_start_id"] = data[0]["im_start_id"] batch_data["im_start_id"] = data[0]["im_start_id"]
...@@ -380,7 +415,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -380,7 +415,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_embedding( def get_embedding(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImagePixelInputs], image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"): if hasattr(self.config, "scale_emb"):
...@@ -389,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -389,7 +424,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_inputs is None: # No image if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device) vision_hidden_states = torch.tensor([], device=input_ids.device)
else: else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs) if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)
# See NOTE in _parse_and_validate_inputs # See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"] image_bounds = image_inputs["image_bounds"]
...@@ -440,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -440,9 +480,23 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
**kwargs: object, **kwargs: object,
) -> Optional[MiniCPMVImagePixelInputs]: ) -> Optional[MiniCPMVImageInputs]:
pixel_values = kwargs.pop("pixel_values", []) pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", []) tgt_sizes = kwargs.pop("tgt_sizes", [])
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_embeds is not None:
return MiniCPMVImageEmbeddingInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
...@@ -477,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -477,10 +531,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if len(pixel_values_flat) == 0: if len(pixel_values_flat) == 0:
return None return None
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
if im_start_id is None: if im_start_id is None:
return None return None
...@@ -488,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -488,8 +538,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
image_bounds=self._get_image_bounds(input_ids, im_start_id, image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id, im_end_id, slice_start_id,
slice_end_id), slice_end_id),
pixel_values=pixel_values_flat, data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat), tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
) )
def forward( def forward(
...@@ -610,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -610,8 +661,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool: def is_default_weight_loading(self, name: str) -> bool:
...@@ -705,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -705,9 +756,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res.append(self.resampler(vision_embedding, tgt_size)) res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res) return torch.vstack(res)
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["data"]
return self.get_vision_embedding(pixel_values) return self.get_vision_embedding(pixel_values)
...@@ -793,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -793,9 +844,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
vision_embedding = self.resampler(vision_embedding, tgt_sizes) vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding return vision_embedding
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"] tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
...@@ -909,9 +960,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -909,9 +960,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
) )
return vision_embedding return vision_embedding
def get_vision_hidden_states( def get_vision_hidden_states(self,
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"] pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"] tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment