Unverified Commit 502c41a8 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Use helper function to run MM processors with token inputs (where applicable) (#38018)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 52069012
...@@ -1210,6 +1210,17 @@ class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): ...@@ -1210,6 +1210,17 @@ class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Override to use the text path instead of token path to use the
# video-specific logic in processing_keye.py
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -371,6 +371,17 @@ class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo): ...@@ -371,6 +371,17 @@ class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Override to use the text path instead of token path to use the
# video-specific logic in processing_keye.py
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -215,6 +215,17 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]) ...@@ -215,6 +215,17 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
grid_thws=MultiModalFieldConfig.batched("vision_chunk"), grid_thws=MultiModalFieldConfig.batched("vision_chunk"),
) )
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Override to use the text path instead of token path because vision chunk
# is not considered
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -12,7 +12,7 @@ import torch.nn.functional as F ...@@ -12,7 +12,7 @@ import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from transformers import PixtralVisionConfig from transformers import BatchFeature, PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import ( from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens, _num_image_tokens as _get_pixtral_hf_num_image_tokens,
) )
...@@ -62,6 +62,7 @@ from vllm.sequence import IntermediateTensors ...@@ -62,6 +62,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
from vllm.utils.collection_utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
...@@ -213,6 +214,27 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]) ...@@ -213,6 +214,27 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(images=MultiModalFieldConfig.batched("image")) return dict(images=MultiModalFieldConfig.batched("image"))
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
# Avoid padding issue
tok_kwargs={**tok_kwargs, "return_tensors": None},
)
# Missing batch dimension
if is_list_of(outputs["input_ids"], int):
outputs["input_ids"] = [outputs["input_ids"]]
return outputs
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -929,6 +929,17 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): ...@@ -929,6 +929,17 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
second_per_grid_ts=MultiModalFieldConfig.batched("video"), second_per_grid_ts=MultiModalFieldConfig.batched("video"),
) )
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Override to use the text path instead of token path to use the
# video-specific logic in processing_qwen2_5_vl.py
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors ...@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
from vllm.utils.collection_utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix from .utils import init_vllm_registered_model, maybe_prefix
...@@ -208,7 +209,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -208,7 +209,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
) -> None: ) -> None:
# mistral_common's tokenizer's does not follow HF's placeholder norms # mistral_common's tokenizer's does not follow HF's placeholder norms
# skip validation here # skip validation here
... pass
def _call_hf_processor( def _call_hf_processor(
self, self,
...@@ -224,13 +225,20 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -224,13 +225,20 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# MistralCommonVoxtralProcessor accepts "audio" # MistralCommonVoxtralProcessor accepts "audio"
mm_data["audio"] = audios mm_data["audio"] = audios
return super()._call_hf_processor( outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs, # Avoid padding issue
tok_kwargs={**tok_kwargs, "return_tensors": None},
) )
# Missing batch dimension
if is_list_of(outputs["input_ids"], int):
outputs["input_ids"] = [outputs["input_ids"]]
return outputs
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -196,7 +196,7 @@ class InputProcessingContext: ...@@ -196,7 +196,7 @@ class InputProcessingContext:
tokenizer = self.tokenizer tokenizer = self.tokenizer
if is_mistral_tokenizer(tokenizer): if is_mistral_tokenizer(tokenizer):
tokenizer = tokenizer.transformers_tokenizer tokenizer = tokenizer.transformers_tokenizer # type: ignore[union-attr]
merged_kwargs = self.get_merged_mm_kwargs(kwargs) merged_kwargs = self.get_merged_mm_kwargs(kwargs)
merged_kwargs.pop("tokenizer", None) merged_kwargs.pop("tokenizer", None)
...@@ -263,9 +263,10 @@ class InputProcessingContext: ...@@ -263,9 +263,10 @@ class InputProcessingContext:
requires_kw_only=False, requires_kw_only=False,
allow_var_kwargs=True, allow_var_kwargs=True,
) )
allowed_kwargs.setdefault("return_tensors", "pt")
try: try:
output = hf_processor(**data, **allowed_kwargs, return_tensors="pt") output = hf_processor(**data, **allowed_kwargs)
except Exception as exc: except Exception as exc:
# See https://github.com/huggingface/tokenizers/issues/537 # See https://github.com/huggingface/tokenizers/issues/537
if ( if (
......
...@@ -5,8 +5,15 @@ from collections import defaultdict ...@@ -5,8 +5,15 @@ from collections import defaultdict
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache, partial
from typing import TYPE_CHECKING, Generic, NamedTuple, Protocol, TypeAlias, cast from typing import (
TYPE_CHECKING,
Generic,
NamedTuple,
Protocol,
TypeAlias,
cast,
)
import regex as re import regex as re
import torch import torch
...@@ -21,6 +28,7 @@ from vllm.inputs import ( ...@@ -21,6 +28,7 @@ from vllm.inputs import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import call_hf_processor_mm_only
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..inputs import ( from ..inputs import (
...@@ -1150,7 +1158,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1150,7 +1158,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) )
processed_data.update(passthrough_data) processed_data.update(passthrough_data)
(prompt_ids,) = processed_data.pop("input_ids").tolist() input_ids = processed_data.pop("input_ids")
if not isinstance(input_ids, list):
input_ids = input_ids.tolist()
(prompt_ids,) = input_ids
is_update_applied = self._hf_processor_applies_updates( is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text, prompt_text=prompt_text,
...@@ -1213,16 +1225,35 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1213,16 +1225,35 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`DummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder] [`DummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
to go along with the multi-modal data. to go along with the multi-modal data.
""" """
mm_counts = mm_items.get_all_counts() # Custom logic based on text inputs
if type(self)._call_hf_processor != BaseMultiModalProcessor._call_hf_processor:
mm_counts = mm_items.get_all_counts()
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
_, mm_processed_data, _ = self._apply_hf_processor_text_mm( return mm_processed_data
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items, valid_mm_items = mm_items.select(
hf_processor_mm_kwargs=hf_processor_mm_kwargs, {k for k, c in mm_items.get_all_counts().items() if c > 0}
tokenization_kwargs=tokenization_kwargs,
) )
processor_data, passthrough_data = self._get_hf_mm_data(valid_mm_items)
processed_data = self.info.ctx.call_hf_processor(
partial(
call_hf_processor_mm_only,
self.info.get_hf_processor(**hf_processor_mm_kwargs),
),
processor_data,
dict(**hf_processor_mm_kwargs, **tokenization_kwargs),
)
processed_data.update(passthrough_data)
return mm_processed_data return processed_data
def _apply_hf_processor_main( def _apply_hf_processor_main(
self, self,
......
...@@ -11,12 +11,16 @@ from transformers import ( ...@@ -11,12 +11,16 @@ from transformers import (
AutoImageProcessor, AutoImageProcessor,
AutoProcessor, AutoProcessor,
AutoVideoProcessor, AutoVideoProcessor,
BatchFeature,
processing_utils, processing_utils,
) )
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor from transformers.video_processing_utils import BaseVideoProcessor
from transformers.video_utils import VideoInput
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -272,7 +276,6 @@ def get_processor_kwargs_keys( ...@@ -272,7 +276,6 @@ def get_processor_kwargs_keys(
"images_kwargs", "images_kwargs",
"videos_kwargs", "videos_kwargs",
"audio_kwargs", "audio_kwargs",
"common_kwargs",
} }
try: try:
...@@ -523,3 +526,43 @@ def cached_video_processor_from_config( ...@@ -523,3 +526,43 @@ def cached_video_processor_from_config(
processor_cls_overrides=processor_cls, # type: ignore[arg-type] processor_cls_overrides=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs), **_merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs),
) )
def call_hf_processor_mm_only(
processor: ProcessorMixin,
images: ImageInput | None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
) -> BatchFeature:
output_kwargs = processor._merge_kwargs(
get_processor_kwargs_type(processor),
**kwargs,
)
if audio is not None and (
feature_extractor := getattr(processor, "feature_extractor", None)
):
audio_inputs = feature_extractor(audio, **output_kwargs["audio_kwargs"])
audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")
else:
audio_inputs = {}
if images is not None and (
image_processor := getattr(processor, "image_processor", None)
):
images_inputs = image_processor(images=images, **output_kwargs["images_kwargs"])
else:
images_inputs = {}
if videos is not None and (
video_processor := getattr(processor, "video_processor", None)
):
videos_inputs = video_processor(videos=videos, **output_kwargs["videos_kwargs"])
else:
videos_inputs = {}
return BatchFeature(
data={**audio_inputs, **images_inputs, **videos_inputs},
tensor_type=kwargs.get("return_tensors"),
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math import math
from typing import Any from typing import Any, TypedDict
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import BatchFeature, ProcessorMixin, TensorType from transformers import BatchFeature, ProcessorMixin, TensorType
from typing_extensions import TypedDict, Unpack from transformers.processing_utils import ProcessingKwargs
from typing_extensions import Unpack
from vllm.tokenizers.hf import HfTokenizer from vllm.tokenizers.hf import HfTokenizer
...@@ -308,15 +307,22 @@ def process_vision_for_patches( ...@@ -308,15 +307,22 @@ def process_vision_for_patches(
return patches, dims_virtual return patches, dims_virtual
class IsaacImageProcessorKwargs(TypedDict, total=False): class IsaacImagesKwargs(TypedDict, total=False):
patch_size: int patch_size: int
max_num_patches: int max_num_patches: int
min_num_patches: int min_num_patches: int
pixel_shuffle_scale: int pixel_shuffle_scale: int
class IsaacProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
images_kwargs: IsaacImagesKwargs
_defaults = {
"text_kwargs": {"padding": False},
"images_kwargs": {},
}
class IsaacImageProcessor: class IsaacImageProcessor:
valid_kwargs = IsaacImageProcessorKwargs
model_input_names = ["pixel_values", "image_grid_thw"] model_input_names = ["pixel_values", "image_grid_thw"]
def __init__( def __init__(
...@@ -335,7 +341,7 @@ class IsaacImageProcessor: ...@@ -335,7 +341,7 @@ class IsaacImageProcessor:
self, self,
images: Image.Image | list[Image.Image], images: Image.Image | list[Image.Image],
return_tensors: str | TensorType | None = None, return_tensors: str | TensorType | None = None,
**kwargs: Unpack[IsaacImageProcessorKwargs], **kwargs: Unpack[IsaacImagesKwargs],
) -> BatchFeature: ) -> BatchFeature:
"""Preprocess images into format compatible with vLLM input processing.""" """Preprocess images into format compatible with vLLM input processing."""
if not isinstance(images, list): if not isinstance(images, list):
...@@ -349,10 +355,16 @@ class IsaacImageProcessor: ...@@ -349,10 +355,16 @@ class IsaacImageProcessor:
patches, dims_virtual = process_vision_for_patches( patches, dims_virtual = process_vision_for_patches(
image_tensor, image_tensor,
patch_size=self.patch_size, patch_size=kwargs.get("patch_size", self.patch_size),
max_num_patches=self.vision_max_num_patches, max_num_patches=kwargs.get(
min_num_patches=self.vision_min_num_patches, "max_num_patches", self.vision_max_num_patches
pixel_shuffle_scale=self.pixel_shuffle_scale, ),
min_num_patches=kwargs.get(
"min_num_patches", self.vision_min_num_patches
),
pixel_shuffle_scale=kwargs.get(
"pixel_shuffle_scale", self.pixel_shuffle_scale
),
) )
# Isaac packs a dummy temporal dim for images # Isaac packs a dummy temporal dim for images
...@@ -405,13 +417,17 @@ class IsaacProcessor(ProcessorMixin): ...@@ -405,13 +417,17 @@ class IsaacProcessor(ProcessorMixin):
text: str | list[str] | None = None, text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None, images: Image.Image | list[Image.Image] | None = None,
return_tensors: str | TensorType | None = None, return_tensors: str | TensorType | None = None,
**kwargs, **kwargs: Unpack[IsaacProcessorKwargs], # type: ignore[misc]
) -> BatchFeature: ) -> BatchFeature:
output_kwargs = self._merge_kwargs(
IsaacProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None: if images is not None:
image_inputs = self.image_processor( image_inputs = self.image_processor(
images, images, **output_kwargs["images_kwargs"]
return_tensors=return_tensors,
**kwargs,
) )
image_grid_thw = image_inputs["image_grid_thw"] image_grid_thw = image_inputs["image_grid_thw"]
else: else:
...@@ -435,7 +451,7 @@ class IsaacProcessor(ProcessorMixin): ...@@ -435,7 +451,7 @@ class IsaacProcessor(ProcessorMixin):
index += 1 index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
text_inputs = self.tokenizer(text, return_tensors=return_tensors) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
else: else:
text_inputs = {} text_inputs = {}
......
...@@ -5,10 +5,7 @@ from mistral_common.protocol.instruct.chunk import ImageChunk ...@@ -5,10 +5,7 @@ from mistral_common.protocol.instruct.chunk import ImageChunk
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image from PIL import Image
from transformers import BatchFeature, ProcessorMixin, TensorType from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.audio_utils import AudioInput
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
...@@ -55,62 +52,16 @@ class MistralCommonPixtralProcessor(ProcessorMixin): ...@@ -55,62 +52,16 @@ class MistralCommonPixtralProcessor(ProcessorMixin):
def __init__(self, tokenizer: MistralTokenizer) -> None: def __init__(self, tokenizer: MistralTokenizer) -> None:
self.tokenizer = tokenizer.transformers_tokenizer self.tokenizer = tokenizer.transformers_tokenizer
# Back-compatibility for Transformers v4
if not hasattr(self.tokenizer, "init_kwargs"):
self.tokenizer.init_kwargs = {}
self.image_processor = MistralCommonImageProcessor( self.image_processor = MistralCommonImageProcessor(
tokenizer.instruct.mm_encoder tokenizer.instruct.mm_encoder
) )
self._image_special_ids = self.image_processor.mm_encoder.special_ids image_special_ids = self.image_processor.mm_encoder.special_ids
self.image_break_id = image_special_ids.img_break
@property self.image_token_id = image_special_ids.img
def image_break_id(self) -> int: self.image_end_id = image_special_ids.img_end
return self._image_special_ids.img_break
@property
def image_token_id(self) -> int:
return self._image_special_ids.img
@property
def image_end_id(self) -> int:
return self._image_special_ids.img_end
def __call__(
self,
images: ImageInput | None = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
):
if images is None and text is None and videos is None and audio is None:
raise ValueError(
f"You need to provide at least one input to "
f"call {self.__class__.__name__}"
)
kwargs = self._merge_kwargs(
self.valid_processor_kwargs,
tokenizer_init_kwargs={},
**kwargs,
)
kwargs["text_kwargs"]["return_tensors"] = "pt"
kwargs["images_kwargs"]["return_tensors"] = None # Avoid padding issue
attribute_to_kwargs = {
"tokenizer": (text, "text_kwargs"),
"image_processor": (images, "images_kwargs"),
"video_processor": (videos, "videos_kwargs"),
"feature_extractor": (audio, "audio_kwargs"),
}
outputs = {}
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name, None)
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
if input_data is not None and attribute is not None:
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)
return BatchFeature(outputs)
...@@ -8,9 +8,6 @@ import torch ...@@ -8,9 +8,6 @@ import torch
from mistral_common.tokens.tokenizers.audio import AudioEncoder from mistral_common.tokens.tokenizers.audio import AudioEncoder
from transformers import BatchFeature, ProcessorMixin, TensorType from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.audio_utils import AudioInput from transformers.audio_utils import AudioInput
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
...@@ -62,58 +59,15 @@ class MistralCommonVoxtralProcessor(ProcessorMixin): ...@@ -62,58 +59,15 @@ class MistralCommonVoxtralProcessor(ProcessorMixin):
def __init__(self, tokenizer: MistralTokenizer) -> None: def __init__(self, tokenizer: MistralTokenizer) -> None:
self.tokenizer = tokenizer.transformers_tokenizer self.tokenizer = tokenizer.transformers_tokenizer
# Back-compatibility for Transformers v4
if not hasattr(self.tokenizer, "init_kwargs"):
self.tokenizer.init_kwargs = {}
self.feature_extractor = MistralCommonFeatureExtractor( self.feature_extractor = MistralCommonFeatureExtractor(
tokenizer.instruct.audio_encoder tokenizer.instruct.audio_encoder
) )
self._audio_special_ids = self.feature_extractor.audio_encoder.special_ids audio_special_ids = self.feature_extractor.audio_encoder.special_ids
self.audio_token_id = audio_special_ids.audio
@property self.begin_audio_token_id = audio_special_ids.begin_audio
def audio_token_id(self) -> int:
return self._audio_special_ids.audio
@property
def begin_audio_token_id(self) -> int:
return self._audio_special_ids.begin_audio
def __call__(
self,
images: ImageInput | None = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
):
if images is None and text is None and videos is None and audio is None:
raise ValueError(
f"You need to provide at least one input to "
f"call {self.__class__.__name__}"
)
kwargs = self._merge_kwargs(
self.valid_processor_kwargs,
tokenizer_init_kwargs={},
**kwargs,
)
kwargs["text_kwargs"]["return_tensors"] = "pt"
kwargs["audio_kwargs"]["return_tensors"] = None # Avoid padding issue
attribute_to_kwargs = {
"tokenizer": (text, "text_kwargs"),
"image_processor": (images, "images_kwargs"),
"video_processor": (videos, "videos_kwargs"),
"feature_extractor": (audio, "audio_kwargs"),
}
outputs = {}
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name, None)
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
if input_data is not None and attribute is not None:
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)
return BatchFeature(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment