"vscode:/vscode.git/clone" did not exist on "75c7ad9918c7fa6d16323db963c19b5f7c45fcfd"
Unverified commit 17699280, authored by kYLe and committed by GitHub
Browse files

[Model] Update Paligemma multimodal processing with PromptUpdate (#14015)


Signed-off-by: Kyle Huang <kylhuang@nvidia.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent ed6ea065
...@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generativ
* *
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `PaliGemmaForConditionalGeneration`\* - * `PaliGemmaForConditionalGeneration`
* PaliGemma, PaliGemma 2 * PaliGemma (see note), PaliGemma 2 (see note)
* T + I<sup>E</sup> * T + I<sup>E</sup>
* `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
* *
* ✅︎ * ✅︎
* * ✅︎
- * `Phi3VForCausalLM` - * `Phi3VForCausalLM`
* Phi-3-Vision, Phi-3.5-Vision * Phi-3-Vision, Phi-3.5-Vision
* T + I<sup>E+</sup> * T + I<sup>E+</sup>
......
...@@ -116,9 +116,8 @@ VLM_TEST_SETTINGS = { ...@@ -116,9 +116,8 @@ VLM_TEST_SETTINGS = {
"pixel_values" "pixel_values"
), ),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm() dtype="bfloat16",
else ("half", "float")), marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
marks=[pytest.mark.core_model],
), ),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgraded to transformers>=4.49.0. # once we upgraded to transformers>=4.49.0.
......
...@@ -175,6 +175,8 @@ def _test_processing_correctness( ...@@ -175,6 +175,8 @@ def _test_processing_correctness(
"Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_4", "fixie-ai/ultravox-v0_4",
"openai/whisper-large-v3", "openai/whisper-large-v3",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
]) ])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
......
...@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, ...@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
import torch import torch
from torch import nn from torch import nn
from transformers import PaliGemmaConfig from transformers import BatchFeature, PaliGemmaConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptReplacement,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -46,97 +50,152 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, ...@@ -46,97 +50,152 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs] PaliGemmaImageEmbeddingInputs]
def get_max_paligemma_image_tokens(ctx: InputContext): class PaliGemmaMultiModalProjector(nn.Module):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
)
mm_data = dummy_image_for_siglip(vision_config, num_images) def __init__(self, vision_hidden_size: int, projection_dim: int):
return DummyData(seq_data, mm_data, ranges) super().__init__()
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
def input_processor_for_paligemma(ctx: InputContext, def forward(self, image_features: torch.Tensor) -> torch.Tensor:
inputs: DecoderOnlyInputs): hidden_states = self.linear(image_features)
return hidden_states
"""
The correct prompt format needs to be:
'<image>' * image_feature_size + '<bos>' + prompt + '\n'
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 class PaliGemmaProcessingInfo(BaseProcessingInfo):
""" # noqa
multi_modal_data = inputs.get("multi_modal_data") def get_hf_config(self):
if multi_modal_data is None or "image" not in multi_modal_data: return self.ctx.get_hf_config(PaliGemmaConfig)
return inputs
model_config = ctx.model_config def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
hf_config = ctx.get_hf_config(PaliGemmaConfig) return {"image": 1}
tokenizer = cached_tokenizer_from_config(model_config) def get_mm_max_tokens_per_item(
image_feature_size = hf_config.text_config.num_image_tokens self,
image_token_str = tokenizer.decode(hf_config.image_token_index) seq_len: int,
bos_token = tokenizer.decode(hf_config.bos_token_id) mm_counts: Mapping[str, int],
image_token_str_pad = image_token_str * image_feature_size ) -> Mapping[str, int]:
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size return {"image": self.get_num_image_tokens()}
orig_prompt = inputs.get("prompt") def get_num_image_tokens(self) -> int:
orig_prompt_ids = inputs.get("prompt_token_ids") hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
return get_max_siglip_image_tokens(vision_config)
if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
" documented on HuggingFace.", image_token_str)
orig_prompt = orig_prompt.replace(image_token_str, "")
orig_prompt_ids.remove(hf_config.image_token_index)
new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n" class PaliGemmaDummyInputsBuilder(
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
# The PaliGemma 2 tokenizer does not include a starting BOS token def get_dummy_processor_inputs(
if orig_prompt_ids[0] != hf_config.bos_token_id: self,
orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline num_images = mm_counts.get("image", 0)
# NOTE: Create a defensive copy of the original inputs mm_data = {
return token_inputs(prompt_token_ids=new_token_ids, "image":
prompt=new_prompt, self._get_dummy_images(width=max_image_size,
multi_modal_data=multi_modal_data) height=max_image_size,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int): class PaliGemmaMultiModalProcessor(
super().__init__() BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True) def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if not mm_data:
prompt_ids = tokenizer.encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
def forward(self, image_features: torch.Tensor) -> torch.Tensor: def _get_mm_fields_config(
hidden_states = self.linear(image_features) self,
return hidden_states hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer()
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
# The PaliGemma 1 and PaliGemma 2 tokenizers differ in `add_bos_token`:
# - PaliGemma 1: insert <image>*n + <bos> after the existing leading <bos>
# - PaliGemma 2: insert <image>*n + <bos> at the very start of the prompt
return [
PromptInsertion(
modality="image",
target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []),
insertion=PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
)
]
@MULTIMODAL_REGISTRY.register_image_input_mapper() def apply(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) self,
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) prompt: Union[str, list[int]],
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()
newline_prompt = "\n"
newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108
# Always append a newline at the end of the prompt, as PaliGemma's format requires.
# This step can NOT be replaced by the current PromptUpdate methods
if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
prompt_token_ids.append(newline_token_id)
mm_inputs["prompt_token_ids"] = prompt_token_ids
mm_inputs["prompt"] += newline_prompt
return mm_inputs
@MULTIMODAL_REGISTRY.register_processor(
PaliGemmaMultiModalProcessor,
info=PaliGemmaProcessingInfo,
dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsV0Only): SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment