Unverified Commit 17699280 authored by kYLe's avatar kYLe Committed by GitHub
Browse files

[Model] Update Paligemma multimodal processing with PromptUpdate (#14015)


Signed-off-by: default avatarKyle Huang <kylhuang@nvidia.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent ed6ea065
......@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `PaliGemmaForConditionalGeneration`\*
* PaliGemma, PaliGemma 2
- * `PaliGemmaForConditionalGeneration`
* PaliGemma (see note), PaliGemma 2 (see note)
* T + I<sup>E</sup>
* `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
*
* ✅︎
*
* ✅︎
- * `Phi3VForCausalLM`
* Phi-3-Vision, Phi-3.5-Vision
* T + I<sup>E+</sup>
......
......@@ -116,9 +116,8 @@ VLM_TEST_SETTINGS = {
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm()
else ("half", "float")),
marks=[pytest.mark.core_model],
dtype="bfloat16",
marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgraded to transformers>=4.49.0.
......
......@@ -175,6 +175,8 @@ def _test_processing_correctness(
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_4",
"openai/whisper-large-v3",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
......
......@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
import torch
from torch import nn
from transformers import PaliGemmaConfig
from transformers import BatchFeature, PaliGemmaConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptReplacement,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
......@@ -46,97 +50,152 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs]
def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
def input_processor_for_paligemma(ctx: InputContext,
inputs: DecoderOnlyInputs):
class PaliGemmaMultiModalProjector(nn.Module):
"""
The correct prompt format needs to be:
'<image>' * image_feature_size + '<bos>' + prompt + '\n'
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
""" # noqa
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear(image_features)
return hidden_states
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PaliGemmaConfig)
tokenizer = cached_tokenizer_from_config(model_config)
image_feature_size = hf_config.text_config.num_image_tokens
image_token_str = tokenizer.decode(hf_config.image_token_index)
bos_token = tokenizer.decode(hf_config.bos_token_id)
image_token_str_pad = image_token_str * image_feature_size
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
class PaliGemmaProcessingInfo(BaseProcessingInfo):
orig_prompt = inputs.get("prompt")
orig_prompt_ids = inputs.get("prompt_token_ids")
def get_hf_config(self):
return self.ctx.get_hf_config(PaliGemmaConfig)
if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
" documented on HuggingFace.", image_token_str)
orig_prompt = orig_prompt.replace(image_token_str, "")
orig_prompt_ids.remove(hf_config.image_token_index)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
# The PaliGemma 2 tokenizer does not include a starting BOS token
if orig_prompt_ids[0] != hf_config.bos_token_id:
orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
return get_max_siglip_image_tokens(vision_config)
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
class PaliGemmaDummyInputsBuilder(
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()
class PaliGemmaMultiModalProcessor(
BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if not mm_data:
prompt_ids = tokenizer.encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear(image_features)
return hidden_states
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer()
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
# Paligemma 1 and 2 have different tokenizer.add_bos_token
# Insert <image>*n + <bos> after <bos> for Paligemma 1
# Insert <image>*n + <bos> for Paligemma 2
return [
PromptInsertion(
modality="image",
target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []),
insertion=PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
)
]
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()
newline_prompt = "\n"
newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108
# Force to add newline at the end of prompt for paligemma's format
# This step can NOT be replacemented by current PromptUpdate methods
if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
prompt_token_ids.append(newline_token_id)
mm_inputs["prompt_token_ids"] = prompt_token_ids
mm_inputs["prompt"] += newline_prompt
return mm_inputs
@MULTIMODAL_REGISTRY.register_processor(
PaliGemmaMultiModalProcessor,
info=PaliGemmaProcessingInfo,
dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsV0Only):
SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment