Commit 61c6a5a7 (unverified), authored by Rémi Delacourt, committed by GitHub
Browse files

[VLM] Merged multi-modal processor for Pixtral (#12211)


Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 74bc397b
...@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams ...@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
# python demo.py advanced # python demo.py advanced
def run_simple_demo(): def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409" model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192) sampling_params = SamplingParams(max_tokens=8192)
# Lower max_num_seqs or max_model_len on low-VRAM GPUs. # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
llm = LLM(model=model_name, tokenizer_mode="mistral") llm = LLM(
model=model_name,
tokenizer_mode="mistral",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "Describe this image in one sentence." prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300" image_url = "https://picsum.photos/id/237/200/300"
...@@ -76,7 +82,7 @@ def run_simple_demo(): ...@@ -76,7 +82,7 @@ def run_simple_demo():
print(outputs[0].outputs[0].text) print(outputs[0].outputs[0].text)
def run_advanced_demo(): def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409" model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5 max_img_per_msg = 5
max_tokens_per_img = 4096 max_tokens_per_img = 4096
...@@ -87,6 +93,7 @@ def run_advanced_demo(): ...@@ -87,6 +93,7 @@ def run_advanced_demo():
tokenizer_mode="mistral", tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg}, limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img, max_model_len=max_img_per_msg * max_tokens_per_img,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
prompt = "Describe the following image." prompt = "Describe the following image."
...@@ -153,14 +160,19 @@ def main(): ...@@ -153,14 +160,19 @@ def main():
help="Specify the demo mode: 'simple' or 'advanced'", help="Specify the demo mode: 'simple' or 'advanced'",
) )
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disables caching of multi-modal preprocessor/mapper.')
args = parser.parse_args() args = parser.parse_args()
if args.mode == "simple": if args.mode == "simple":
print("Running simple demo...") print("Running simple demo...")
run_simple_demo() run_simple_demo(args)
elif args.mode == "advanced": elif args.mode == "advanced":
print("Running advanced demo...") print("Running advanced demo...")
run_advanced_demo() run_advanced_demo(args)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -2,17 +2,23 @@ ...@@ -2,17 +2,23 @@
import copy import copy
from functools import partial from functools import partial
from typing import Optional from typing import Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.processing import ProcessingCache from vllm.multimodal.inputs import MultiModalInputs
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from ....multimodal.utils import random_audio, random_image, random_video from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS from ...registry import HF_EXAMPLE_MODELS
...@@ -85,14 +91,6 @@ def _test_processing_correctness( ...@@ -85,14 +91,6 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
} }
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
mm_data = { mm_data = {
k: k:
...@@ -115,43 +113,131 @@ def _test_processing_correctness( ...@@ -115,43 +113,131 @@ def _test_processing_correctness(
elif len(mm_data[k]) == 1: elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0] mm_data[k] = mm_data[k][0]
baseline_result = baseline_processor.apply( if isinstance(tokenizer, MistralTokenizer):
prompt, _test_processing_correctness_mistral(
mm_data=mm_data, model_config,
hf_processor_mm_kwargs={}, tokenizer,
) prompt,
cached_result = cached_processor.apply( mm_data,
prompt, baseline_processor,
mm_data=mm_data, cached_processor,
hf_processor_mm_kwargs={}, batch_idx,
) ignore_mm_keys=ignore_mm_keys,
)
assert _drop_mm_kwargs_keys( else:
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( _test_processing_correctness_hf(
cached_result, ignore_mm_keys), ( model_config,
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") tokenizer,
prompt,
baseline_tokenized_result = baseline_processor.apply( mm_data,
tokenizer.encode(prompt, **tokenizer_encode_kwargs), baseline_processor,
mm_data=mm_data, cached_processor,
hf_processor_mm_kwargs={}, batch_idx,
) ignore_mm_keys=ignore_mm_keys,
)
assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
baseline_tokenized_result, ignore_mm_keys), ( def _test_processing_correctness_hf(
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
cached_tokenized_result = cached_processor.apply( prompt: str,
tokenizer.encode(prompt, **tokenizer_encode_kwargs), mm_data: MultiModalDataDict,
mm_data=mm_data, baseline_processor: BaseMultiModalProcessor,
hf_processor_mm_kwargs={}, cached_processor: BaseMultiModalProcessor,
) batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
assert _drop_mm_kwargs_keys( ):
cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys( if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
cached_tokenized_result, ignore_mm_keys), ( # For some multimodal models, tokenizer will always add bos_token
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") # at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
token_prompt = tokenizer.encode(prompt)
baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
def _test_processing_correctness_mistral(
    model_config: ModelConfig,
    tokenizer: MistralTokenizer,
    prompt: str,
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
    ignore_mm_keys: Optional[list[str]] = None,
):
    """Check that baseline and cached processors agree on a Mistral prompt.

    The prompt text and any images under ``mm_data["image"]`` are packed
    into a single user message, encoded to token ids through
    ``tokenizer.mistral.encode_chat_completion``, and the resulting tokens
    are fed to both processors. The two outputs must compare equal via
    ``_inputs_equal`` (which receives ``ignore_mm_keys``).

    Args:
        model_config: Not referenced by this helper.
        tokenizer: Tokenizer exposing ``.mistral.encode_chat_completion``.
        prompt: Text placed in the leading text chunk of the message.
        mm_data: Multi-modal data; only the ``"image"`` entry is read when
            building the request, but the whole dict is given to the
            processors.
        baseline_processor: Processor whose output is the reference.
        cached_processor: Processor whose output is compared against it.
        batch_idx: Only used in the assertion failure message.
        ignore_mm_keys: Forwarded unchanged to ``_inputs_equal``.

    Raises:
        AssertionError: If the two processor outputs are not equal.
    """
    # A lone image (anything that is not a list) is wrapped into a list so
    # the request can be built uniformly.
    image_items = mm_data.get("image", [])
    image_items = (image_items
                   if isinstance(image_items, list) else [image_items])

    # One user message: the text chunk first, then one chunk per image.
    chunks = [TextChunk(text=prompt)]
    chunks.extend(ImageChunk(image=item) for item in image_items)
    chat_request = ChatCompletionRequest(
        messages=[UserMessage(content=chunks)])

    # Mistral chat outputs tokens directly, rather than text prompts
    token_prompt = tokenizer.mistral.encode_chat_completion(
        chat_request).tokens

    # Baseline runs first, then the cached processor, on identical inputs.
    baseline_tokenized_result, cached_tokenized_result = [
        processor.apply(
            token_prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        ) for processor in (baseline_processor, cached_processor)
    ]

    outputs_match = _inputs_equal(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys,
    )
    assert outputs_match, f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
# yapf: disable # yapf: disable
...@@ -173,6 +259,7 @@ def _test_processing_correctness( ...@@ -173,6 +259,7 @@ def _test_processing_correctness(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3", "TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-V-2_6",
...@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v( ...@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
) )
def _drop_mm_kwargs_keys(result: dict, def _inputs_equal(
ignore_mm_keys: Optional[list[str]] = None) -> dict: a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)
def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs']. """Drop specified keys from result['mm_kwargs'].
This is mainly to avoid doing exact match of audio_features in ultravox. This is mainly to avoid doing exact match of audio_features in ultravox.
......
...@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
A boolean mask indicating which image embeddings correspond A boolean mask indicating which image embeddings correspond
to patch tokens. to patch tokens.
Shape: `(batch_size, num_embeds)` Shape: `(batch_size, num_images, num_embeds)`
""" """
num_crops: Union[torch.Tensor, list[torch.Tensor]] num_patches: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`""" """Shape: `(batch_size, num_images)`"""
...@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor( ...@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2], image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"] ) for pixel_value in processed_outputs["pixel_values"]
] ]
num_crops = torch.tensor([(ncols + 1) * nrows num_patches = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes]) for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to # Each image may result to masks of different sizes, so we need to
# flatten the list and later use `num_crops` to get per-image masks. # later use `num_patches` to get per-image masks.
embed_is_patch = torch.tensor( embed_is_patch = [
flatten_2d_lists([([True] * ncols + [False]) * nrows torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes])) for ncols, nrows in tile_sizes
processed_outputs["num_crops"] = num_crops ]
processed_outputs["num_patches"] = num_patches
processed_outputs["embed_is_patch"] = embed_is_patch processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["feat_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
...@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor( ...@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict( return dict(
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
feat_is_patch = kwargs.pop("feat_is_patch")
if not isinstance(feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch") embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)): if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. " raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}") f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops") num_patches = kwargs.pop("num_patches")
if not isinstance(num_crops, (torch.Tensor, list)): if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_crops)}") f"Got type: {type(num_patches)}")
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch, embed_is_patch=embed_is_patch,
num_crops=num_crops, num_patches=num_patches,
) )
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
...@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
vision_tower: Union[CLIPVisionModel, SiglipVisionModel, vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel], PixtralHFVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]], pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since # NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower # this is already done inside the vision tower
image_features = vision_tower(pixel_values) image_features = vision_tower(pixel_values)
return self._select_image_features( def select_features(leaf: torch.Tensor):
image_features, return self._select_image_features(
strategy=self.config.vision_feature_select_strategy, leaf,
strategy=self.config.vision_feature_select_strategy,
)
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
) )
def _process_image_pixels( def _process_image_pixels(
self, self,
inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs], inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
...@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _get_mm_embeds( def _get_mm_embeds(
self, self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d) features: torch.Tensor, # Shape: (num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) num_patches: torch.Tensor, # Shape: (num_images,)
num_crops: torch.Tensor, # Shape: (num_images,) embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,) ) -> tuple[torch.Tensor, ...]:
) -> list[torch.Tensor]:
"""Scatter the patch features into a contiguous tensor that corresponds """Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor. to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
""" """
# Insert columns of nan values according to `embed_is_patch`. This work
# Insert columns of nan values according to `feat_is_patch`. This work
# ideally should be done in `_process_image_input`, but # ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to # `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here. # put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is # FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`. # deprecated. Merge this function with `Molmo._get_mm_embeds`.
feat_is_patch = feat_is_patch.view(-1) num_patches_per_image: list[int] = num_patches.tolist()
embed_is_patch = embed_is_patch.view(-1)
expanded_embedding = torch.full(
(sum(num_crops), *features.shape[1:]),
torch.nan,
dtype=features.dtype).to(features.device)
expanded_embedding[feat_is_patch] = features
num_crops_per_image = num_crops.tolist() embeds_flat = features.new_full(
feats_per_image = expanded_embedding.split(num_crops_per_image) (sum(num_patches_per_image), *features.shape[1:]),
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) fill_value=torch.nan,
)
embed_dim = expanded_embedding.shape[-1] embeds_flat[embed_is_patch.view(-1)] = features
num_embeds = embed_is_patch.shape[0]
embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
return embeds_in_batch return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
...@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# The path is used for pixtral (V0 only) and llava (V0/V1) # The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings return vision_embeddings
nested_emb = [ return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip( self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"], vision_embeddings,
image_input["num_crops"], image_input["embed_is_patch"]) image_input["num_patches"],
] image_input["embed_is_patch"],
return flatten_2d_lists(nested_emb) ))
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) )
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, cast(NestedTensors, input_ids,
patch_embeddings), inputs_embeds,
self.config.image_token_index) cast(NestedTensors, patch_embeddings),
self.config.image_token_index,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
nested_embeds = [ return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip( self._get_mm_embeds(*args) for args in zip(
image_features, image_features,
image_input["feat_is_patch"], image_input["feat_is_patch"],
image_input["num_crops"], image_input["num_crops"],
image_input["embed_is_patch"], image_input["embed_is_patch"],
) ))
]
return flatten_2d_lists(nested_embeds)
def get_input_embeddings( def get_input_embeddings(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, from typing import Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptReplacement, PromptInsertion, PromptUpdate,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor( ...@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
......
This diff is collapsed.
...@@ -77,7 +77,9 @@ class PromptIndexTargets: ...@@ -77,7 +77,9 @@ class PromptIndexTargets:
else: else:
if isinstance(prefix, str): if isinstance(prefix, str):
# Make both `list[int]` # Make both `list[int]`
prefix = encode_tokens(tokenizer, prefix) prefix = encode_tokens(tokenizer,
prefix,
add_special_tokens=False)
match_idx = len(prefix) match_idx = len(prefix)
return match_idx if prompt[:match_idx] == prefix else None return match_idx if prompt[:match_idx] == prefix else None
...@@ -318,7 +320,7 @@ def _cached_encode( ...@@ -318,7 +320,7 @@ def _cached_encode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
text: str, text: str,
*, *,
add_special_tokens: bool = False, add_special_tokens: Optional[bool] = None,
) -> list[int]: ) -> list[int]:
return encode_tokens(tokenizer, return encode_tokens(tokenizer,
text, text,
...@@ -330,7 +332,7 @@ def _cached_decode( ...@@ -330,7 +332,7 @@ def _cached_decode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
token_ids: tuple[int, ...], token_ids: tuple[int, ...],
*, *,
skip_special_tokens: bool = False, skip_special_tokens: Optional[bool] = None,
) -> str: ) -> str:
return decode_tokens(tokenizer, return decode_tokens(tokenizer,
list(token_ids), list(token_ids),
...@@ -395,7 +397,9 @@ class _BoundPromptSequence: ...@@ -395,7 +397,9 @@ class _BoundPromptSequence:
def token_ids(self) -> list[int]: def token_ids(self) -> list[int]:
if self._token_ids is None: if self._token_ids is None:
assert self._text is not None assert self._text is not None
self._token_ids = _cached_encode(self.tokenizer, self._text) self._token_ids = _cached_encode(self.tokenizer,
self._text,
add_special_tokens=False)
return self._token_ids return self._token_ids
...@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptUpdate]: ) -> Sequence[PromptUpdate]:
""" """
Given the original multi-modal items for this modality Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform. and HF-processed data, output the updates to perform.
......
...@@ -34,13 +34,20 @@ def decode_tokens( ...@@ -34,13 +34,20 @@ def decode_tokens(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
token_ids: list[int], token_ids: list[int],
*, *,
skip_special_tokens: bool = False, skip_special_tokens: Optional[bool] = None,
) -> str: ) -> str:
""" """
Backend-agnostic equivalent of HF's Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`. :code:`tokenizer.decode(token_ids, ...)`.
:code:`skip_special_tokens=None` means to use the backend's default
settings.
""" """
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) if skip_special_tokens is not None:
return tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
return tokenizer.decode(token_ids)
def encode_tokens( def encode_tokens(
...@@ -51,10 +58,14 @@ def encode_tokens( ...@@ -51,10 +58,14 @@ def encode_tokens(
) -> list[int]: ) -> list[int]:
""" """
Backend-agnostic equivalent of HF's Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`. :code:`tokenizer.encode(text, ...)`.
:code:`add_special_tokens=None` means to use the backend's default
settings.
""" """
if add_special_tokens is not None: if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens) return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text) return tokenizer.encode(text)
......
...@@ -845,7 +845,7 @@ def is_list_of( ...@@ -845,7 +845,7 @@ def is_list_of(
assert_never(check) assert_never(check)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]: def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list.""" """Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist] return [item for sublist in lists for item in sublist]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment