Unverified Commit 61c6a5a7 authored by Rémi Delacourt's avatar Rémi Delacourt Committed by GitHub
Browse files

[VLM] Merged multi-modal processor for Pixtral (#12211)


Signed-off-by: default avatarremi <remi@mistral.ai>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 74bc397b
......@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
# python demo.py advanced
def run_simple_demo():
def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192)
# Lower max_num_seqs or max_model_len on low-VRAM GPUs.
llm = LLM(model=model_name, tokenizer_mode="mistral")
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
llm = LLM(
model=model_name,
tokenizer_mode="mistral",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300"
......@@ -76,7 +82,7 @@ def run_simple_demo():
print(outputs[0].outputs[0].text)
def run_advanced_demo():
def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5
max_tokens_per_img = 4096
......@@ -87,6 +93,7 @@ def run_advanced_demo():
tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "Describe the following image."
......@@ -153,14 +160,19 @@ def main():
help="Specify the demo mode: 'simple' or 'advanced'",
)
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disables caching of multi-modal preprocessor/mapper.')
args = parser.parse_args()
if args.mode == "simple":
print("Running simple demo...")
run_simple_demo()
run_simple_demo(args)
elif args.mode == "advanced":
print("Running advanced demo...")
run_advanced_demo()
run_advanced_demo(args)
if __name__ == "__main__":
......
......@@ -2,17 +2,23 @@
import copy
from functools import partial
from typing import Optional
from typing import Optional, Union
import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
......@@ -85,14 +91,6 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches):
mm_data = {
k:
......@@ -115,6 +113,49 @@ def _test_processing_correctness(
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]
if isinstance(tokenizer, MistralTokenizer):
_test_processing_correctness_mistral(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
def _test_processing_correctness_hf(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
token_prompt = tokenizer.encode(prompt)
baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
......@@ -126,32 +167,77 @@ def _test_processing_correctness(
hf_processor_mm_kwargs={},
)
assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
cached_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
assert _inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
baseline_tokenized_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
assert _inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens
# Mistral chat outputs tokens directly, rather than text prompts
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _drop_mm_kwargs_keys(
cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
cached_tokenized_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
assert _inputs_equal(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
# yapf: disable
......@@ -173,6 +259,7 @@ def _test_processing_correctness(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
......@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
)
def _drop_mm_kwargs_keys(result: dict,
ignore_mm_keys: Optional[list[str]] = None) -> dict:
def _inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)
def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs'].
This is mainly to avoid doing exact match of audio_features in ultravox.
......
......@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_embeds)`
Shape: `(batch_size, num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
num_patches: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
......@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_crops = torch.tensor([(ncols + 1) * nrows
num_patches = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# flatten the list and later use `num_crops` to get per-image masks.
embed_is_patch = torch.tensor(
flatten_2d_lists([([True] * ncols + [False]) * nrows
for ncols, nrows in tile_sizes]))
processed_outputs["num_crops"] = num_crops
# later use `num_patches` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_patches"] = num_patches
processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["feat_is_patch"] = embed_is_patch
return processed_outputs
......@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict(
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
......@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral":
feat_is_patch = kwargs.pop("feat_is_patch")
if not isinstance(feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
num_patches=num_patches,
)
return LlavaImagePixelInputs(
......@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
image_features,
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
def _process_image_pixels(
self,
inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
......@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]:
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `feat_is_patch`. This work
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
feat_is_patch = feat_is_patch.view(-1)
embed_is_patch = embed_is_patch.view(-1)
expanded_embedding = torch.full(
(sum(num_crops), *features.shape[1:]),
torch.nan,
dtype=features.dtype).to(features.device)
expanded_embedding[feat_is_patch] = features
num_patches_per_image: list[int] = num_patches.tolist()
num_crops_per_image = num_crops.tolist()
feats_per_image = expanded_embedding.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
embed_dim = expanded_embedding.shape[-1]
num_embeds = embed_is_patch.shape[0]
embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_in_batch
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
......@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings
nested_emb = [
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)
vision_embeddings,
image_input["num_patches"],
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
......@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, cast(NestedTensors,
patch_embeddings),
self.config.image_token_index)
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
self.config.image_token_index,
)
return inputs_embeds
def forward(
......
......@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input)
nested_embeds = [
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["feat_is_patch"],
image_input["num_crops"],
image_input["embed_is_patch"],
)
]
return flatten_2d_lists(nested_embeds)
))
def get_input_embeddings(
self,
......
# SPDX-License-Identifier: Apache-2.0
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
......@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptReplacement,
PromptInsertion, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......
This diff is collapsed.
......@@ -77,7 +77,9 @@ class PromptIndexTargets:
else:
if isinstance(prefix, str):
# Make both `list[int]`
prefix = encode_tokens(tokenizer, prefix)
prefix = encode_tokens(tokenizer,
prefix,
add_special_tokens=False)
match_idx = len(prefix)
return match_idx if prompt[:match_idx] == prefix else None
......@@ -318,7 +320,7 @@ def _cached_encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: bool = False,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
return encode_tokens(tokenizer,
text,
......@@ -330,7 +332,7 @@ def _cached_decode(
tokenizer: AnyTokenizer,
token_ids: tuple[int, ...],
*,
skip_special_tokens: bool = False,
skip_special_tokens: Optional[bool] = None,
) -> str:
return decode_tokens(tokenizer,
list(token_ids),
......@@ -395,7 +397,9 @@ class _BoundPromptSequence:
def token_ids(self) -> list[int]:
if self._token_ids is None:
assert self._text is not None
self._token_ids = _cached_encode(self.tokenizer, self._text)
self._token_ids = _cached_encode(self.tokenizer,
self._text,
add_special_tokens=False)
return self._token_ids
......@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptUpdate]:
) -> Sequence[PromptUpdate]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
......
......@@ -34,13 +34,20 @@ def decode_tokens(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
skip_special_tokens: Optional[bool] = None,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
:code:`tokenizer.decode(token_ids, ...)`.
:code:`skip_special_tokens=None` means to use the backend's default
settings.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
if skip_special_tokens is not None:
return tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
return tokenizer.decode(token_ids)
def encode_tokens(
......@@ -51,10 +58,14 @@ def encode_tokens(
) -> list[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
:code:`tokenizer.encode(text, ...)`.
:code:`add_special_tokens=None` means to use the backend's default
settings.
"""
if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)
......
......@@ -845,7 +845,7 @@ def is_list_of(
assert_never(check)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment