Unverified Commit 61c6a5a7 authored by Rémi Delacourt, committed by GitHub
Browse files

[VLM] Merged multi-modal processor for Pixtral (#12211)


Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 74bc397b
...@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams ...@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
# python demo.py advanced # python demo.py advanced
def run_simple_demo(): def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409" model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192) sampling_params = SamplingParams(max_tokens=8192)
# Lower max_num_seqs or max_model_len on low-VRAM GPUs. # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
llm = LLM(model=model_name, tokenizer_mode="mistral") llm = LLM(
model=model_name,
tokenizer_mode="mistral",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "Describe this image in one sentence." prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300" image_url = "https://picsum.photos/id/237/200/300"
...@@ -76,7 +82,7 @@ def run_simple_demo(): ...@@ -76,7 +82,7 @@ def run_simple_demo():
print(outputs[0].outputs[0].text) print(outputs[0].outputs[0].text)
def run_advanced_demo(): def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409" model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5 max_img_per_msg = 5
max_tokens_per_img = 4096 max_tokens_per_img = 4096
...@@ -87,6 +93,7 @@ def run_advanced_demo(): ...@@ -87,6 +93,7 @@ def run_advanced_demo():
tokenizer_mode="mistral", tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg}, limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img, max_model_len=max_img_per_msg * max_tokens_per_img,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = "Describe the following image." prompt = "Describe the following image."
...@@ -153,14 +160,19 @@ def main(): ...@@ -153,14 +160,19 @@ def main():
help="Specify the demo mode: 'simple' or 'advanced'", help="Specify the demo mode: 'simple' or 'advanced'",
) )
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disables caching of multi-modal preprocessor/mapper.')
args = parser.parse_args() args = parser.parse_args()
if args.mode == "simple": if args.mode == "simple":
print("Running simple demo...") print("Running simple demo...")
run_simple_demo() run_simple_demo(args)
elif args.mode == "advanced": elif args.mode == "advanced":
print("Running advanced demo...") print("Running advanced demo...")
run_advanced_demo() run_advanced_demo(args)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -2,17 +2,23 @@ ...@@ -2,17 +2,23 @@
import copy import copy
from functools import partial from functools import partial
from typing import Optional from typing import Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.processing import ProcessingCache from vllm.multimodal.inputs import MultiModalInputs
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from ....multimodal.utils import random_audio, random_image, random_video from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS from ...registry import HF_EXAMPLE_MODELS
...@@ -85,14 +91,6 @@ def _test_processing_correctness( ...@@ -85,14 +91,6 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
} }
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
mm_data = { mm_data = {
k: k:
...@@ -115,6 +113,49 @@ def _test_processing_correctness( ...@@ -115,6 +113,49 @@ def _test_processing_correctness(
elif len(mm_data[k]) == 1: elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0] mm_data[k] = mm_data[k][0]
if isinstance(tokenizer, MistralTokenizer):
_test_processing_correctness_mistral(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
def _test_processing_correctness_hf(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
token_prompt = tokenizer.encode(prompt)
baseline_result = baseline_processor.apply( baseline_result = baseline_processor.apply(
prompt, prompt,
mm_data=mm_data, mm_data=mm_data,
...@@ -126,32 +167,77 @@ def _test_processing_correctness( ...@@ -126,32 +167,77 @@ def _test_processing_correctness(
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _drop_mm_kwargs_keys( assert _inputs_equal(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( baseline_result,
cached_result, ignore_mm_keys), ( cached_result,
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
baseline_tokenized_result = baseline_processor.apply( baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs), token_prompt,
mm_data=mm_data, mm_data=mm_data,
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _drop_mm_kwargs_keys( assert _inputs_equal(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( baseline_result,
baseline_tokenized_result, ignore_mm_keys), ( baseline_tokenized_result,
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens
# Mistral chat outputs tokens directly, rather than text prompts
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply( cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs), token_prompt,
mm_data=mm_data, mm_data=mm_data,
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
assert _drop_mm_kwargs_keys( assert _inputs_equal(
cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys( baseline_tokenized_result,
cached_tokenized_result, ignore_mm_keys), ( cached_tokenized_result,
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
# yapf: disable # yapf: disable
...@@ -173,6 +259,7 @@ def _test_processing_correctness( ...@@ -173,6 +259,7 @@ def _test_processing_correctness(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3", "TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-V-2_6",
...@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v( ...@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
) )
def _drop_mm_kwargs_keys(result: dict, def _inputs_equal(
ignore_mm_keys: Optional[list[str]] = None) -> dict: a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)
def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs']. """Drop specified keys from result['mm_kwargs'].
This is mainly to avoid doing exact match of audio_features in ultravox. This is mainly to avoid doing exact match of audio_features in ultravox.
......
...@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_embeds)` Shape: `(batch_size, num_images, num_embeds)`
""" """
num_crops: Union[torch.Tensor, list[torch.Tensor]] num_patches: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`""" """Shape: `(batch_size, num_images)`"""
...@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor( ...@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2], image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"] ) for pixel_value in processed_outputs["pixel_values"]
] ]
num_crops = torch.tensor([(ncols + 1) * nrows num_patches = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes]) for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to # Each image may result to masks of different sizes, so we need to
# flatten the list and later use `num_crops` to get per-image masks. # later use `num_patches` to get per-image masks.
embed_is_patch = torch.tensor( embed_is_patch = [
flatten_2d_lists([([True] * ncols + [False]) * nrows torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes])) for ncols, nrows in tile_sizes
processed_outputs["num_crops"] = num_crops ]
processed_outputs["num_patches"] = num_patches
processed_outputs["embed_is_patch"] = embed_is_patch processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["feat_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
...@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor( ...@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict( return dict(
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
feat_is_patch = kwargs.pop("feat_is_patch")
if not isinstance(feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch") embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)): if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops") num_patches = kwargs.pop("num_patches")
if not isinstance(num_crops, (torch.Tensor, list)): if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_crops)}") f"Got type: {type(num_patches)}")
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch, embed_is_patch=embed_is_patch,
num_crops=num_crops, num_patches=num_patches,
) )
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
...@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
vision_tower: Union[CLIPVisionModel, SiglipVisionModel, vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel], PixtralHFVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]], pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since # NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower # this is already done inside the vision tower
image_features = vision_tower(pixel_values) image_features = vision_tower(pixel_values)
def select_features(leaf: torch.Tensor):
return self._select_image_features( return self._select_image_features(
image_features, leaf,
strategy=self.config.vision_feature_select_strategy, strategy=self.config.vision_feature_select_strategy,
) )
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
def _process_image_pixels( def _process_image_pixels(
self, self,
inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs], inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
...@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _get_mm_embeds( def _get_mm_embeds(
self, self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d) features: torch.Tensor, # Shape: (num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) num_patches: torch.Tensor, # Shape: (num_images,)
num_crops: torch.Tensor, # Shape: (num_images,) embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,) ) -> tuple[torch.Tensor, ...]:
) -> list[torch.Tensor]:
"""Scatter the patch features into a contiguous tensor that corresponds """Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor. to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
""" """
# Insert columns of nan values according to `embed_is_patch`. This work
# Insert columns of nan values according to `feat_is_patch`. This work
# ideally should be done in `_process_image_input`, but # ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to # `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here. # put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is # FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`. # deprecated. Merge this function with `Molmo._get_mm_embeds`.
feat_is_patch = feat_is_patch.view(-1) num_patches_per_image: list[int] = num_patches.tolist()
embed_is_patch = embed_is_patch.view(-1)
expanded_embedding = torch.full(
(sum(num_crops), *features.shape[1:]),
torch.nan,
dtype=features.dtype).to(features.device)
expanded_embedding[feat_is_patch] = features
num_crops_per_image = num_crops.tolist() embeds_flat = features.new_full(
feats_per_image = expanded_embedding.split(num_crops_per_image) (sum(num_patches_per_image), *features.shape[1:]),
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) fill_value=torch.nan,
)
embed_dim = expanded_embedding.shape[-1] embeds_flat[embed_is_patch.view(-1)] = features
num_embeds = embed_is_patch.shape[0]
embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
return embeds_in_batch return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
...@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# The path is used for pixtral (V0 only) and llava (V0/V1) # The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings return vision_embeddings
nested_emb = [ return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip( self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"], vision_embeddings,
image_input["num_crops"], image_input["embed_is_patch"]) image_input["num_patches"],
] image_input["embed_is_patch"],
return flatten_2d_lists(nested_emb) ))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) )
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, cast(NestedTensors, input_ids,
patch_embeddings), inputs_embeds,
self.config.image_token_index) cast(NestedTensors, patch_embeddings),
self.config.image_token_index,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
nested_embeds = [ return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip( self._get_mm_embeds(*args) for args in zip(
image_features, image_features,
image_input["feat_is_patch"], image_input["feat_is_patch"],
image_input["num_crops"], image_input["num_crops"],
image_input["embed_is_patch"], image_input["embed_is_patch"],
) ))
]
return flatten_2d_lists(nested_embeds)
def get_input_embeddings( def get_input_embeddings(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, from typing import Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptReplacement, PromptInsertion, PromptUpdate,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor( ...@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import cached_property from functools import cached_property
from typing import List, Optional, Set, Tuple, Union from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk from mistral_common.protocol.instruct.messages import ImageChunk
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image from PIL import Image
from transformers import PixtralVisionConfig from transformers import PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import ( from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens) _num_image_tokens as _get_pixtral_hf_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import ( from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.jsontree import JSONTree, json_map_leaves
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -31,13 +33,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -31,13 +33,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
from vllm.sequence import IntermediateTensors, SequenceData MultiModalDataItems)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix, from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
...@@ -48,132 +57,275 @@ except ImportError: ...@@ -48,132 +57,275 @@ except ImportError:
USE_XFORMERS_OPS = False USE_XFORMERS_OPS = False
def get_max_pixtral_image_tokens(ctx: InputContext): class PixtralImagePixelInputs(TypedDict):
tokenizer = cached_tokenizer_from_config(ctx.model_config) type: Literal["pixel_values"]
mm_encoder = tokenizer.instruct.mm_encoder
image_config = mm_encoder.mm_config if hasattr( images: Union[torch.Tensor, list[torch.Tensor]]
mm_encoder, "mm_config") else mm_encoder.image_config """
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
max_image_size = image_config.max_image_size The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
image_patch_size = image_config.image_patch_size """
return ((max_image_size // image_patch_size)**2) embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
num_patches: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class PixtralProcessorAdapter:
"""
Provide a HF-compatible interface for
:class:`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
@property
def image_processor(self) -> ImageEncoder:
image_encoder = self.tokenizer.instruct.mm_encoder
assert isinstance(image_encoder, ImageEncoder)
return image_encoder
@cached_property
def image_break_id(self) -> int:
return self.image_processor.special_ids.img_break
@cached_property
def image_token_id(self) -> int:
return self.image_processor.special_ids.img
@cached_property
def image_end_id(self) -> int:
return self.image_processor.special_ids.img_end
@cached_property
def image_size(self) -> int:
return self.image_processor.mm_config.max_image_size
@cached_property
def patch_size(self) -> int:
return self.image_processor.mm_config.image_patch_size
def __call__(
self,
text: Optional[Union[TextInput, list[TextInput]]] = None,
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if not images:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
# Allow dummy text, which is used for profiling as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, image_token_id = self.image_token_id
mm_counts: Mapping[str, int]):
tokenizer = cached_tokenizer_from_config(ctx.model_config)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder images_processed = list[torch.Tensor]()
image_token_id = mm_encoder.special_ids.img images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
images_num_patches = list[int]()
mm_config = ctx.get_mm_config() for image in images:
num_images = mm_config.get_limit_per_prompt("image") image_inputs = self.image_processor(ImageChunk(image=image))
# dummy size image_processed = torch.tensor(image_inputs.image)
size = 256 image_tokens = torch.tensor(image_inputs.tokens)
image = Image.new("RGB", (size, size), color=0)
encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image)) images_processed.append(image_processed)
image_feature_size = len(encoding.tokens) images_tokens.append(image_tokens)
num_image_tokens = image_feature_size * num_images images_embed_is_patch.append(image_tokens == image_token_id)
seq_data = SequenceData.from_prompt_token_counts( images_num_patches.append(len(image_tokens))
(image_token_id, num_image_tokens),
(0, seq_len - num_image_tokens), return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
"num_patches": torch.tensor(images_num_patches),
}
class PixtralProcessingInfo(BaseProcessingInfo):
    """Processing info for Pixtral, answered from the Mistral tokenizer's
    image encoder."""

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the cached tokenizer for ``self.ctx.model_config``.

        Raises:
            ValueError: If the cached tokenizer is not a
                ``MistralTokenizer``.
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if isinstance(tokenizer, MistralTokenizer):
            return tokenizer

        raise ValueError("This model requires `--tokenizer-mode mistral`")

    def get_hf_processor(self) -> PixtralProcessorAdapter:
        """Build a new ``PixtralProcessorAdapter`` around the tokenizer."""
        tokenizer = self.get_tokenizer()
        return PixtralProcessorAdapter(tokenizer)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; only ``"image"`` is listed."""
        # NOTE(review): `None` presumably means "no fixed upper bound";
        # confirm against `BaseProcessingInfo`.
        limits: dict[str, Optional[int]] = {"image": None}
        return limits

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image.

        Neither ``seq_len`` nor ``mm_counts`` is consulted.
        """
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def get_vision_config(
        self,
        processor: Optional[PixtralProcessorAdapter] = None,
    ):
        """Build a ``PixtralVisionConfig`` from the processor's image and
        patch sizes.

        Args:
            processor: Adapter to read from; a new one is created when
                omitted.
        """
        adapter = self.get_hf_processor() if processor is None else processor

        return PixtralVisionConfig(
            image_size=adapter.image_size,
            patch_size=adapter.patch_size,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[PixtralProcessorAdapter] = None,
    ) -> int:
        """Return the number of placeholder tokens for an image of the
        given size.

        Args:
            image_width: Width passed to ``Image.new``.
            image_height: Height passed to ``Image.new``.
            processor: Adapter to read from; a new one is created when
                omitted.
        """
        adapter = self.get_hf_processor() if processor is None else processor

        # The encoder helper takes an image object, so a blank image of the
        # requested size is created just for the measurement.
        count_tokens = adapter.image_processor._image_to_num_tokens
        ncols, nrows = count_tokens(
            Image.new("RGB", (image_width, image_height)))

        # Each row carries `ncols` image tokens plus one trailing token.
        return (ncols + 1) * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square ``ImageSize`` whose side is the encoder's
        ``mm_config.max_image_size``."""
        mm_config = self.get_hf_processor().image_processor.mm_config
        side = mm_config.max_image_size
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size reported by
        ``get_image_size_with_most_features``."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)
mm_data = {"image": num_images * [image]}
mm_placeholders = { class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image": "image":
consecutive_placeholder_ranges(num_items=num_images, self._get_dummy_images(width=target_width,
item_size=image_feature_size) height=target_height,
num_images=num_images)
} }
return DummyData(seq_data, mm_data, mm_placeholders)
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
def input_mapper_for_pixtral(ctx: InputContext,
data: object) -> MultiModalKwargs:
"""Maps the input data to its MultiModalKwargs (if any).
Args: class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
ctx: Context of the loaded model. ):
data: data potentially containing PIL images to be processed
and mapped to `images`.
Returns: def _get_mm_fields_config(
MultiModalKwargs containing the stacked normalized images tensor or self,
image embeddings. hf_inputs: Mapping[str, NestedTensors],
""" hf_processor_mm_kwargs: Mapping[str, object],
tokenizer = cached_tokenizer_from_config(ctx.model_config) ) -> Mapping[str, MultiModalFieldConfig]:
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
)
data_list = data if isinstance(data, list) else [data] def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
images = [] image_break_id = processor.image_break_id
image_tokens_list = [] image_token_id = processor.image_token_id
for image_data in data_list: image_end_id = processor.image_end_id
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
images.append(image)
image_tokens_list.append(encoding.tokens)
image_tokens = torch.tensor([
token_id for image_tokens in image_tokens_list
for token_id in image_tokens
])
return MultiModalKwargs({"images": images, "image_tokens": image_tokens})
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): ncols, nrows = processor.image_processor._image_to_num_tokens(
multi_modal_data = inputs.get("multi_modal_data") Image.new("RGB", (image_size.width, image_size.height)))
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
prompt_token_ids = inputs.get("prompt_token_ids") tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
prompt = inputs.get("prompt") tokens[-1] = image_end_id
tokenizer = cached_tokenizer_from_config(ctx.model_config)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder return tokens
image_token_id = mm_encoder.special_ids.img
image_break_id = mm_encoder.special_ids.img_break
image_end_id = mm_encoder.special_ids.img_end
if image_token_id not in inputs['prompt_token_ids']: return [
raise ValueError( PromptReplacement(
f"You've passed {inputs=} without {image_token_id=}" modality="image",
" Make sure to process your input via mistral_common's" target="", # Never match the prompt (see below note)
" tokenizer or pass a chat completion request. For more" replacement=get_replacement,
" For more info, see: " ),
"https://github.com/vllm-project/vllm/issues/8411.") ]
# Get precise tracking of placeholder positions def _cached_apply_hf_processor(
placeholder_ranges = [] self,
curr_offset = -1 prompt: Union[str, list[int]],
curr_length = 0 mm_data_items: MultiModalDataItems,
for i in range(len(prompt_token_ids)): hf_processor_mm_kwargs: Mapping[str, object],
if prompt_token_ids[i] in (image_token_id, image_break_id): ) -> tuple[list[int], MultiModalKwargs, bool]:
if curr_offset < 0: prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
curr_offset = i prompt=prompt,
curr_length += 1 mm_data_items=mm_data_items,
elif prompt_token_ids[i] == image_end_id: hf_processor_mm_kwargs=hf_processor_mm_kwargs,
curr_length += 1 )
placeholder_ranges.append(
PlaceholderRange(offset=curr_offset, length=curr_length)) # NOTE: The tokens are already inserted by the chat template
curr_offset = -1 return prompt_ids, mm_kwargs, True
curr_length = 0
else:
pass
return token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) info=PixtralProcessingInfo,
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) dummy_inputs=PixtralDummyInputsBuilder)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
...@@ -191,13 +343,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -191,13 +343,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if key in dataclass_fields if key in dataclass_fields
} }
if not ("image_break_token_id" in vision_args
and "image_end_token_id" in vision_args):
raise ValueError(
"'image_break_token_id' and 'image_end_token_id' not found "
"in the vision_encoder arguments. Please download the latest "
"version of 'params.json' from the model repository.")
self.vision_args = VisionEncoderArgs(**vision_args) self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM # init MistralForCausalLM
...@@ -221,36 +366,92 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -221,36 +366,92 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler() return get_sampler()
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
images = kwargs.pop("images", None)
if images is None:
return None
if not isinstance(images, (torch.Tensor, list)):
raise ValueError("Incorrect type of images. "
f"Got type: {type(images)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
num_patches=num_patches,
)
def _process_image_input(
self,
image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]
image_embeds = self.vision_language_adapter(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image: list[int] = num_patches.tolist()
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input, image_tokens = self._parse_and_validate_image_input( image_input = self._parse_and_validate_image_input(**kwargs)
**kwargs)
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
# NOTE: We patch the outputs of the vision encoder with embeddings
# from `[IMG_BREAK]` and `[IMG_END]` tokens.
image_embeds = self.language_model.get_input_embeddings(image_tokens)
image_token_mask = image_tokens == self.vision_args.image_token_id
image_embeds[image_token_mask] = vision_embeddings
# NOTE: Image embeddings are split into separate tensors for each image if kwargs.get("v0_path", False):
# by the indices of `[IMG_END]` token. return image_features
image_end_mask = image_tokens == self.vision_args.image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
# If the last split index is the last index in image_tokens, we return flatten_2d_lists(
# ignore it to avoid empty split tensor self._get_mm_embeds(*args) for args in zip(
if split_indices[-1] == len(image_tokens): image_features,
split_indices = split_indices[:-1] image_input["num_patches"],
image_input["embed_is_patch"],
image_embeds = image_embeds.tensor_split(split_indices.cpu()) ))
return image_embeds
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -259,12 +460,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -259,12 +460,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [ input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
self.vision_args.image_token_id, self.vision_args.image_token_id,
self.vision_args.image_break_token_id, )
self.vision_args.image_end_token_id,
])
return inputs_embeds return inputs_embeds
def forward( def forward(
...@@ -275,14 +481,14 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -275,14 +481,14 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for pixtral. """Run forward pass for pixtral."""
"""
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
...@@ -295,47 +501,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -295,47 +501,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return hidden_states return hidden_states
def _parse_and_validate_image_input(
self,
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None,
image_tokens: Optional[torch.Tensor] = None,
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
if images is None:
return None, None
if isinstance(images, torch.Tensor):
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as list flatten lists of tensors
flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
flatten_images.extend(imgs_per_req)
images = flatten_images
if isinstance(image_tokens, torch.Tensor):
# image_tokens are batched
image_tokens = image_tokens.flatten()
elif isinstance(image_tokens, list):
# image_tokens are of different lengths thus passed as a list
image_tokens = torch.cat(image_tokens)
assert image_tokens.dim() == 1
return images, image_tokens
def _process_image_input(self,
image_input: List[torch.Tensor]) -> torch.Tensor:
return self.vision_language_adapter(self.vision_encoder(image_input))
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -400,8 +565,6 @@ class VisionEncoderArgs: ...@@ -400,8 +565,6 @@ class VisionEncoderArgs:
num_attention_heads: int num_attention_heads: int
rope_theta: float # for rope-2D rope_theta: float # for rope-2D
image_token_id: int image_token_id: int
image_break_token_id: int
image_end_token_id: int
adapter_bias: bool = True adapter_bias: bool = True
...@@ -637,9 +800,13 @@ class VisionTransformer(nn.Module): ...@@ -637,9 +800,13 @@ class VisionTransformer(nn.Module):
self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
] ]
patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
embed_sizes = [p.shape[1] for p in patch_embeds]
# flatten to a single sequence # flatten to a single sequence
patch_embeds = torch.cat( patch_embeds = torch.cat(patch_embeds, dim=1)
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = self.ln_pre(patch_embeds) patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings # positional embeddings
...@@ -655,8 +822,8 @@ class VisionTransformer(nn.Module): ...@@ -655,8 +822,8 @@ class VisionTransformer(nn.Module):
"with the Mistral format") "with the Mistral format")
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
# remove batch dimension of the single sequence # squeeze dim 0 and split into separate tensors for each image
return out.squeeze(0) return torch.split(out.squeeze(0), embed_sizes)
class VisionLanguageAdapter(nn.Module): class VisionLanguageAdapter(nn.Module):
...@@ -978,9 +1145,9 @@ class PixtralHFVisionModel(nn.Module): ...@@ -978,9 +1145,9 @@ class PixtralHFVisionModel(nn.Module):
def forward( def forward(
self, self,
pixel_values: List[torch.Tensor], pixel_values: list[torch.Tensor],
feature_sample_layers: Optional[list[int]] = None, feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor: ) -> tuple[torch.Tensor, ...]:
""" """
Args: Args:
pixel_values: Each image to be processed will be a separate tensor pixel_values: Each image to be processed will be a separate tensor
...@@ -1039,8 +1206,7 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1039,8 +1206,7 @@ class PixtralHFVisionModel(nn.Module):
self.config.num_hidden_layers) self.config.num_hidden_layers)
# squeeze dim 0 and split into separate tensors for each image # squeeze dim 0 and split into separate tensors for each image
out = torch.split(torch.squeeze(out), embed_sizes) return torch.split(out.squeeze(0), embed_sizes)
return out
# (TODO) Add prefix argument for filtering out weights to be loaded # (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
......
...@@ -77,7 +77,9 @@ class PromptIndexTargets: ...@@ -77,7 +77,9 @@ class PromptIndexTargets:
else: else:
if isinstance(prefix, str): if isinstance(prefix, str):
# Make both `list[int]` # Make both `list[int]`
prefix = encode_tokens(tokenizer, prefix) prefix = encode_tokens(tokenizer,
prefix,
add_special_tokens=False)
match_idx = len(prefix) match_idx = len(prefix)
return match_idx if prompt[:match_idx] == prefix else None return match_idx if prompt[:match_idx] == prefix else None
...@@ -318,7 +320,7 @@ def _cached_encode( ...@@ -318,7 +320,7 @@ def _cached_encode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
text: str, text: str,
*, *,
add_special_tokens: bool = False, add_special_tokens: Optional[bool] = None,
) -> list[int]: ) -> list[int]:
return encode_tokens(tokenizer, return encode_tokens(tokenizer,
text, text,
...@@ -330,7 +332,7 @@ def _cached_decode( ...@@ -330,7 +332,7 @@ def _cached_decode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
token_ids: tuple[int, ...], token_ids: tuple[int, ...],
*, *,
skip_special_tokens: bool = False, skip_special_tokens: Optional[bool] = None,
) -> str: ) -> str:
return decode_tokens(tokenizer, return decode_tokens(tokenizer,
list(token_ids), list(token_ids),
...@@ -395,7 +397,9 @@ class _BoundPromptSequence: ...@@ -395,7 +397,9 @@ class _BoundPromptSequence:
def token_ids(self) -> list[int]: def token_ids(self) -> list[int]:
if self._token_ids is None: if self._token_ids is None:
assert self._text is not None assert self._text is not None
self._token_ids = _cached_encode(self.tokenizer, self._text) self._token_ids = _cached_encode(self.tokenizer,
self._text,
add_special_tokens=False)
return self._token_ids return self._token_ids
...@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptUpdate]: ) -> Sequence[PromptUpdate]:
""" """
Given the original multi-modal items for this modality Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform. and HF-processed data, output the updates to perform.
......
...@@ -34,13 +34,20 @@ def decode_tokens( ...@@ -34,13 +34,20 @@ def decode_tokens(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
token_ids: list[int], token_ids: list[int],
*, *,
skip_special_tokens: bool = False, skip_special_tokens: Optional[bool] = None,
) -> str: ) -> str:
""" """
Backend-agnostic equivalent of HF's Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`. :code:`tokenizer.decode(token_ids, ...)`.
:code:`skip_special_tokens=None` means to use the backend's default
settings.
""" """
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) if skip_special_tokens is not None:
return tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
return tokenizer.decode(token_ids)
def encode_tokens( def encode_tokens(
...@@ -51,10 +58,14 @@ def encode_tokens( ...@@ -51,10 +58,14 @@ def encode_tokens(
) -> list[int]: ) -> list[int]:
""" """
Backend-agnostic equivalent of HF's Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`. :code:`tokenizer.encode(text, ...)`.
:code:`add_special_tokens=None` means to use the backend's default
settings.
""" """
if add_special_tokens is not None: if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens) return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text) return tokenizer.encode(text)
......
...@@ -845,7 +845,7 @@ def is_list_of( ...@@ -845,7 +845,7 @@ def is_list_of(
assert_never(check) assert_never(check)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]: def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list.""" """Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist] return [item for sublist in lists for item in sublist]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment