Unverified Commit e01ab595 authored by whyiug's avatar whyiug Committed by GitHub
Browse files

[Model] support input embeddings for qwen2vl (#8856)

parent f13a07b1
...@@ -281,7 +281,7 @@ Multimodal Language Models ...@@ -281,7 +281,7 @@ Multimodal Language Models
- -
* - :code:`Qwen2VLForConditionalGeneration` * - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL - Qwen2-VL
- Image\ :sup:`+` / Video\ :sup:`+` - Image\ :sup:`E+` / Video\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
- -
* - :code:`UltravoxModel` * - :code:`UltravoxModel`
......
...@@ -60,7 +60,24 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT ...@@ -60,7 +60,24 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
print(generated_text) print(generated_text)
# Inference with image embeddings as input with additional parameters
# Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
mm_data['image'] = {
"image_embeds": image_embeds,
"image_grid_thw": image_grid_thw,
}
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": mm_data,
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Batch inference # Batch inference
image_1 = PIL.Image.open(...) image_1 = PIL.Image.open(...)
image_2 = PIL.Image.open(...) image_2 = PIL.Image.open(...)
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Union) Tuple, Type, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -76,19 +76,31 @@ logger = init_logger(__name__) ...@@ -76,19 +76,31 @@ logger = init_logger(__name__)
# === Vision Inputs === # # === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict): class Qwen2VLImagePixelInputs(TypedDict):
pixel_values: torch.Tensor type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: """Shape:
`(num_patches, num_channels * patch_size * patch_size)` `(num_patches, num_channels * patch_size * patch_size)`
""" """
image_grid_thw: torch.Tensor image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)` """Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format. This should be in `(grid_t, grid_h, grid_w)` format.
""" """
class Qwen2VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
Qwen2VLImageEmbeddingInputs]
class Qwen2VLVideoInputs(TypedDict): class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor pixel_values_videos: torch.Tensor
"""Shape: """Shape:
...@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl( ...@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
data_type_key: str, data_type_key: str,
) -> MultiModalInputs: ) -> MultiModalInputs:
"""Input mapper for Qwen2-VL.""" """Input mapper for Qwen2-VL."""
if data_type_key == "image" and isinstance(data, dict):
return MultiModalInputs({
"image_embeds": data.get("image_embeds"),
"image_grid_thw": data.get("image_grid_thw"),
})
model_config = ctx.model_config model_config = ctx.model_config
image_processor = cached_get_image_processor( image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code) model_config.model, trust_remote_code=model_config.trust_remote_code)
...@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens( ...@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
return llm_num_vision_tokens return llm_num_vision_tokens
def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
data_type_key: str, image_processor: Any,
prompt_token_ids: List[int]) -> List[int]:
"""
Expand pad tokens for multi-modal inputs (e.g., images or videos).
Args:
inputs (list): The multi-modal inputs (e.g., images or videos).
token_id (int): The token ID used to represent the multi-modal input.
make_batched_fn (Callable): A function to batch the inputs.
data_type_key (str): The type of the multi-modal input.
image_processor (Any): The image processor used to process the inputs.
prompt_token_ids (List[int]): The list of token IDs in the prompt.
Returns:
List[int]: The list of token IDs for the multi-modal inputs.
"""
indices = [
idx for idx, token in enumerate(prompt_token_ids) if token == token_id
]
inputs = make_batched_fn(inputs)
assert len(indices) == len(inputs)
prompt_token_ids_with_data = []
for cnt, data in enumerate(inputs):
num_tokens = _get_llm_num_vision_tokens(
[data] if data_type_key == "image" else data,
data_type_key=data_type_key,
image_processor=image_processor,
)
if cnt == 0:
end_idx = indices[cnt]
non_data_tokens = prompt_token_ids[:end_idx]
else:
non_data_tokens = prompt_token_ids[indices[cnt - 1] +
1:indices[cnt]]
prompt_token_ids_with_data.extend(non_data_tokens)
prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
return prompt_token_ids_with_data
def input_processor_for_qwen2_vl(ctx: InputContext, def input_processor_for_qwen2_vl(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs: llm_inputs: LLMInputs) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data", None) multi_modal_data = llm_inputs.get("multi_modal_data", None)
...@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext, ...@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
)["input_ids"] )["input_ids"]
# Expand image pad tokens. # Expand image pad tokens.
if image_inputs is not None: if image_inputs is not None:
image_indices = [ if isinstance(image_inputs, dict):
idx for idx, token in enumerate(prompt_token_ids) prompt_token_ids_with_image = []
if token == hf_config.image_token_id image_indices = [
] idx for idx, token in enumerate(prompt_token_ids)
image_inputs = make_batched_images(image_inputs) if token == hf_config.image_token_id
assert len(image_indices) == len(image_inputs) ]
image_cnt = len(image_indices)
prompt_token_ids_with_image = [] embed_dim = image_inputs.get('image_embeds').size(0)
for image_cnt, image in enumerate(image_inputs): assert embed_dim % image_cnt == 0
num_image_tokens = _get_llm_num_vision_tokens( num_pad_tokens = embed_dim // image_cnt
[image], for idx, token in enumerate(prompt_token_ids):
data_type_key="image", if idx in image_indices:
image_processor=image_processor, prompt_token_ids_with_image.extend([token] *
) num_pad_tokens)
if image_cnt == 0: else:
non_image_tokens = prompt_token_ids[:image_indices[image_cnt]] prompt_token_ids_with_image.append(token)
else: prompt_token_ids = prompt_token_ids_with_image
non_image_tokens = prompt_token_ids[image_indices[image_cnt - else:
1] + prompt_token_ids = _expand_pad_tokens(image_inputs,
1:image_indices[image_cnt]] hf_config.image_token_id,
prompt_token_ids_with_image.extend(non_image_tokens) make_batched_images, "image",
prompt_token_ids_with_image.extend( image_processor,
hf_config.image_token_id for _ in range(num_image_tokens)) prompt_token_ids)
prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_image
# Expand video pad tokens.
if video_inputs is not None: if video_inputs is not None:
video_indices = [ prompt_token_ids = _expand_pad_tokens(video_inputs,
idx for idx, token in enumerate(prompt_token_ids) hf_config.video_token_id,
if token == hf_config.video_token_id make_batched_videos, "video",
] image_processor,
video_inputs = make_batched_videos(video_inputs) prompt_token_ids)
assert len(video_indices) == len(video_inputs)
prompt_token_ids_with_video = []
for video_cnt, video in enumerate(video_inputs):
num_video_tokens = _get_llm_num_vision_tokens(
video,
data_type_key="video",
image_processor=image_processor,
)
if video_cnt == 0:
non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
else:
non_video_tokens = prompt_token_ids[video_indices[video_cnt -
1] +
1:video_indices[video_cnt]]
prompt_token_ids_with_video.extend(non_video_tokens)
prompt_token_ids_with_video.extend(
hf_config.video_token_id for _ in range(num_video_tokens))
prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_video
return LLMInputs( return LLMInputs(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
...@@ -910,22 +945,32 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -910,22 +945,32 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None) image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None: if pixel_values is None and image_embeds is None:
return None return None
pixel_values = self._validate_and_reshape_mm_tensor( if pixel_values is not None:
pixel_values, "image pixel values") pixel_values = self._validate_and_reshape_mm_tensor(
image_grid_thw = self._validate_and_reshape_mm_tensor( pixel_values, "image pixel values")
image_grid_thw, "image grid_thw") image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. " raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
return Qwen2VLImageInputs(pixel_values=pixel_values, return Qwen2VLImagePixelInputs(type="pixel_values",
image_grid_thw=image_grid_thw) data=pixel_values,
image_grid_thw=image_grid_thw)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
data=image_embeds)
def _parse_and_validate_video_input( def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
...@@ -947,7 +992,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -947,7 +992,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def _process_image_input(self, def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor: image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype) if image_input["type"] == "image_embeds":
return image_input["data"].type(self.visual.dtype)
pixel_values = image_input["data"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"]) grid_thw=image_input["image_grid_thw"])
return image_embeds return image_embeds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment