Unverified Commit d62856b9 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Move processors to `transformers_utils` (#35953)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bd2659a5
...@@ -13,11 +13,7 @@ import numpy as np ...@@ -13,11 +13,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from torchvision import transforms from transformers import BatchFeature
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -50,7 +46,8 @@ from vllm.multimodal.processing import ( ...@@ -50,7 +46,8 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.processors.glm4v import GLM4VProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
...@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel): ...@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel):
) )
class GLM4VProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own one here.
"""
def __init__(
self,
config: ChatGLMConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
vision_config = config.vision_config
image_size = vision_config["image_size"]
self.image_transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
class GLM4VProcessingInfo(BaseProcessingInfo): class GLM4VProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(ChatGLMConfig) return self.ctx.get_hf_config(ChatGLMConfig)
def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor: def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
config = self.get_hf_config()
vision_config = config.vision_config
image_size = vision_config["image_size"]
return self.ctx.init_processor( return self.ctx.init_processor(
GLM4VProcessor, GLM4VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(), tokenizer=self.get_tokenizer(),
**kwargs, **{**kwargs, "image_size": image_size},
) )
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property, partial from functools import partial
from itertools import islice from itertools import islice
from typing import Annotated from typing import Annotated
...@@ -13,9 +13,11 @@ import torch ...@@ -13,9 +13,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType from transformers import (
from transformers.image_utils import ImageInput BaseImageProcessor,
from transformers.tokenization_utils_base import TextInput BatchFeature,
PretrainedConfig,
)
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -1017,117 +1019,28 @@ def select_tiling( ...@@ -1017,117 +1019,28 @@ def select_tiling(
return candidate_tilings[ix] return candidate_tilings[ix]
class MolmoProcessorWrapper: def _as_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
""" if isinstance(x, int):
Wraps `MolmoProcessor` so that it can be called directly. return x, x
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
"""
def __init__(self, processor: ProcessorMixin):
super().__init__()
self.processor = processor
@cached_property
def vocab(self) -> dict[str, int]:
return self.processor.tokenizer.vocab # type: ignore
@cached_property
def max_crops(self) -> int:
image_processor = self.processor.image_processor # type: ignore
max_crops = image_processor.max_crops
assert isinstance(max_crops, int)
return max_crops
@cached_property
def base_image_input_size(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
base_image_input_size = image_processor.base_image_input_size
if isinstance(base_image_input_size, int):
return base_image_input_size, base_image_input_size
return tuple(base_image_input_size)
@cached_property
def image_patch_size(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_patch_size = image_processor.image_patch_size
assert isinstance(image_patch_size, int)
return image_patch_size
@cached_property
def overlap_margins(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
left_margin, right_margin = image_processor.overlap_margins
assert isinstance(left_margin, int)
assert isinstance(right_margin, int)
return left_margin, right_margin
@cached_property
def image_token_length_w(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_w = image_processor.image_token_length_w
assert isinstance(image_token_length_w, int)
return image_token_length_w
@cached_property
def image_token_length_h(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_h = image_processor.image_token_length_h
assert isinstance(image_token_length_h, int)
return image_token_length_h
@property
def message_format(self) -> str | None:
return "role"
@property
def always_start_with_space(self) -> bool:
return True
@cached_property
def image_patch_id(self) -> int:
return self.vocab[IMAGE_PATCH_TOKEN]
@cached_property
def im_col_id(self) -> int:
return self.vocab[IM_COL_TOKEN]
@cached_property return x
def im_start_id(self) -> int:
return self.vocab[IM_START_TOKEN]
@cached_property
def im_end_id(self) -> int:
return self.vocab[IM_END_TOKEN]
@property class MolmoProcessingInfo(BaseProcessingInfo):
def pooling_size(self) -> int: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return POOLING_SIZE return {"image": None}
def select_tiling( def select_tiling(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor: BaseImageProcessor,
) -> tuple[int, int]: ) -> tuple[int, int]:
max_crops = self.max_crops max_crops = image_processor.max_crops
left_margin, right_margin = self.overlap_margins left_margin, right_margin = image_processor.overlap_margins
base_image_input_size = self.base_image_input_size base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
base_image_input_d = self.image_patch_size base_image_input_d = image_processor.image_patch_size
total_margin_pixels = base_image_input_d * (right_margin + left_margin) total_margin_pixels = base_image_input_d * (right_margin + left_margin)
crop_patches = base_image_input_size[0] // base_image_input_d crop_patches = base_image_input_size[0] // base_image_input_d
...@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper: ...@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor: BaseImageProcessor,
) -> tuple[int, int]: ) -> tuple[int, int]:
left_margin, right_margin = self.overlap_margins left_margin, right_margin = image_processor.overlap_margins
base_image_input_size = self.base_image_input_size base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
base_image_input_d = self.image_patch_size base_image_input_d = image_processor.image_patch_size
pooling_size = self.pooling_size pooling_size = POOLING_SIZE
crop_patches = base_image_input_size[0] // base_image_input_d crop_patches = base_image_input_size[0] // base_image_input_d
tiling_w, tiling_h = self.select_tiling( tiling_w, tiling_h = self.select_tiling(
image_height=image_height, image_height=image_height,
image_width=image_width, image_width=image_width,
image_processor=image_processor,
) )
nrows, ncols = get_patches_grid_size( nrows, ncols = get_patches_grid_size(
...@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper: ...@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
return ncols, nrows return ncols, nrows
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
outputs = self.processor.process( # type: ignore
text, images, **kwargs
)
if images is None:
images = []
if not isinstance(images, list):
images = [images]
input_ids: torch.Tensor = outputs.pop("input_ids")
outputs["input_ids"] = input_ids.unsqueeze(0)
image_input_idx = outputs.pop("image_input_idx", None)
if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0
tilings = [
self.select_tiling(
image_width=image.size[0],
image_height=image.size[1],
)
for image in images
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
return BatchFeature(outputs)
class MolmoProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
processor = self.ctx.get_hf_processor(**kwargs)
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: MolmoProcessorWrapper, image_processor: BaseImageProcessor,
) -> int: ) -> int:
ncols, nrows = processor.get_patches_grid_size( ncols, nrows = self.get_patches_grid_size(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
image_processor=image_processor,
) )
pooling_size = processor.pooling_size pooling_size = POOLING_SIZE
image_token_length_w = processor.image_token_length_w image_token_length_w = image_processor.image_token_length_w
image_token_length_h = processor.image_token_length_h image_token_length_h = image_processor.image_token_length_h
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators # Calculate total tokens: 2 for start/end + (w+1)*h for column separators
extra = 2 + (image_token_length_w + 1) * image_token_length_h extra = 2 + (image_token_length_w + 1) * image_token_length_h
...@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
image_processor = processor.image_processor
tilings = get_candidate_tilings(processor.max_crops) tilings = get_candidate_tilings(image_processor.max_crops)
base_h, base_w = processor.base_image_input_size base_h, base_w = _as_2tuple(image_processor.base_image_input_size)
largest_feature_size, largest_feature_pinpoint = 0, None largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in tilings: for wr, hr in tilings:
...@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
feat_size = self.get_num_image_tokens( feat_size = self.get_num_image_tokens(
image_width=width, image_width=width,
image_height=height, image_height=height,
processor=processor, image_processor=image_processor,
) )
if feat_size > largest_feature_size: if feat_size > largest_feature_size:
largest_feature_size = feat_size largest_feature_size = feat_size
...@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): ...@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self.info.get_hf_processor(**mm_kwargs)
processed_outputs = self.info.ctx.call_hf_processor(
hf_processor.process,
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
tokenizer = hf_processor.tokenizer
image_patch_id = tokenizer.vocab[IMAGE_PATCH_TOKEN]
image_processor = hf_processor.image_processor
input_ids: torch.Tensor = processed_outputs.pop("input_ids")
processed_outputs["input_ids"] = input_ids.unsqueeze(0)
if (images := mm_data.get("images")) is not None:
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
feat_is_patch = processed_outputs["image_input_idx"] >= 0
tilings = [
self.info.select_tiling(
image_width=image_size.width,
image_height=image_size.height,
image_processor=image_processor,
)
for image_size in image_sizes
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
processed_outputs["num_crops"] = num_crops
processed_outputs["img_patch_id"] = image_patch_id
return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
self, self,
prompt_tokens: list[int], prompt_tokens: list[int],
...@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
# The chat template is already applied to the prompt tokens # The chat template is already applied to the prompt tokens
# Use message_format="none" to avoid applying it again # Use message_format="none" to avoid applying it again
# Prepend an empty space if `always_start_with_space` is True # Prepend an empty space if `always_start_with_space` is True
tokens = processor.processor.get_tokens_input( # type: ignore tokens = processor.get_tokens_input(
self.info.get_tokenizer().decode(prompt_tokens), self.info.get_tokenizer().decode(prompt_tokens),
message_format="none", message_format="none",
always_start_with_space=processor.always_start_with_space, always_start_with_space=True,
) )
# Prepend a BOS token id to the tokens # Prepend a BOS token id to the tokens
processed_data = self.info.ctx.call_hf_processor( processed_data = self.info.ctx.call_hf_processor(
processor, # type: ignore processor.process,
dict(tokens=tokens), dict(tokens=tokens),
) )
(prompt_ids,) = processed_data.pop("input_ids").tolist() prompt_ids = processed_data.pop("input_ids").tolist()
print(prompt_ids, len(prompt_ids))
return prompt_ids return prompt_ids
...@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
image_token_length_w = processor.image_token_length_w img_patch_id = vocab[IMAGE_PATCH_TOKEN]
image_token_length_h = processor.image_token_length_h img_col_id = vocab[IM_COL_TOKEN]
pooling_size = processor.pooling_size img_start_id = vocab[IM_START_TOKEN]
img_end_id = vocab[IM_END_TOKEN]
img_patch_id = processor.image_patch_id processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
img_col_id = processor.im_col_id image_processor = processor.image_processor
img_start_id = processor.im_start_id image_token_length_w = image_processor.image_token_length_w
img_end_id = processor.im_end_id image_token_length_h = image_processor.image_token_length_h
pooling_size = POOLING_SIZE
extra_row = [img_patch_id] * image_token_length_w + [img_col_id] extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id] extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
...@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
ncols, nrows = processor.get_patches_grid_size( ncols, nrows = self.info.get_patches_grid_size(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
image_processor=image_processor,
) )
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id] joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import cached_property
from typing import Annotated, Literal from typing import Annotated, Literal
import torch import torch
...@@ -13,10 +12,7 @@ import torch.nn.functional as F ...@@ -13,10 +12,7 @@ import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from transformers import PixtralVisionConfig
from PIL import Image
from transformers import BatchFeature, PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import ( from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens, _num_image_tokens as _get_pixtral_hf_num_image_tokens,
) )
...@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import ( ...@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import (
apply_rotary_pos_emb, apply_rotary_pos_emb,
position_ids_in_meshgrid, position_ids_in_meshgrid,
) )
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -66,6 +61,7 @@ from vllm.platforms import current_platform ...@@ -66,6 +61,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
...@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema): ...@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema):
] ]
class PixtralProcessorAdapter:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
@property
def image_processor(self) -> ImageEncoder:
image_encoder = self.tokenizer.instruct.mm_encoder
assert isinstance(image_encoder, ImageEncoder)
return image_encoder
@cached_property
def image_break_id(self) -> int:
return self.image_processor.special_ids.img_break
@cached_property
def image_token_id(self) -> int:
return self.image_processor.special_ids.img
@cached_property
def image_end_id(self) -> int:
return self.image_processor.special_ids.img_end
@cached_property
def image_size(self) -> int:
return self.image_processor.mm_config.max_image_size
@cached_property
def patch_size(self) -> int:
return self.image_processor.mm_config.image_patch_size
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if not images:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
# Allow dummy text, which is used for profiling as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
image_processed = torch.tensor(image_inputs.image)
image_tokens = torch.tensor(image_inputs.tokens)
images_processed.append(image_processed)
images_tokens.append(image_tokens)
return BatchFeature(
{
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
}
)
class PixtralProcessingInfo(BaseProcessingInfo): class PixtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer: def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config) tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
...@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo): ...@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo):
return tokenizer return tokenizer
def get_hf_processor(self) -> PixtralProcessorAdapter: def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor:
return PixtralProcessorAdapter(self.get_tokenizer()) return self.ctx.init_processor(
MistralCommonPixtralProcessor,
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: PixtralProcessorAdapter,
) -> int:
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height))
)
return ncols * nrows
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor image_processor = self.get_hf_processor().image_processor
max_image_size = image_processor.mm_config.max_image_size max_image_size = image_processor.mm_encoder.mm_config.max_image_size
return ImageSize(width=max_image_size, height=max_image_size) return ImageSize(width=max_image_size, height=max_image_size)
...@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]) ...@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
ncols, nrows = processor.image_processor._image_to_num_tokens( _, nrows, ncols = processor.image_processor.get_number_of_image_patches(
Image.new("RGB", (image_size.width, image_size.height)) image_size.height,
image_size.width,
) )
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
......
...@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias ...@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias
import regex as re import regex as re
import torch import torch
from torch import nn from torch import nn
from torchvision import transforms from transformers import BatchFeature
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -48,6 +44,7 @@ from vllm.multimodal.processing import ( ...@@ -48,6 +44,7 @@ from vllm.multimodal.processing import (
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.qwen_vl import QwenVLProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
...@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel): ...@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel):
) )
class QwenVLProcessor: class QwenVLProcessingInfo(BaseProcessingInfo):
""" def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
This model doesn't define its own HF processor, config = self.get_hf_config()
so we implement our own one here.
We call the wrapped tokenizer to automatically insert image pad tokens:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
The image processor is defined here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
"""
def __init__(
self,
config: PretrainedConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
vision_config = config.visual vision_config = config.visual
image_size = vision_config["image_size"] image_size = vision_config["image_size"]
self.image_transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
@property
def image_start_tag(self) -> str:
return self.tokenizer.image_start_tag # type: ignore
@property
def image_end_tag(self) -> str:
return self.tokenizer.image_end_tag # type: ignore
@property
def image_pad_tag(self) -> str:
return self.tokenizer.image_pad_tag # type: ignore
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
class QwenVLProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
return self.ctx.init_processor( return self.ctx.init_processor(
QwenVLProcessor, QwenVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(), tokenizer=self.get_tokenizer(),
**kwargs, **{**kwargs, "image_size": image_size},
) )
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
......
...@@ -3,25 +3,19 @@ ...@@ -3,25 +3,19 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import partial
from math import ceil
from typing import Literal, cast from typing import Literal, cast
import numpy as np import numpy as np
import regex as re import regex as re
import torch import torch
import torch.nn as nn import torch.nn as nn
from mistral_common.audio import mel_filter_bank from mistral_common.audio import Audio, mel_filter_bank
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import ( from transformers import BatchFeature, WhisperConfig
Audio,
AudioEncoder,
)
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
...@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import ( ...@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix from .utils import init_vllm_registered_model, maybe_prefix
...@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = { ...@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = {
} }
class VoxtralProcessorAdapter:
"""
Provide a HF-compatible interface for
:class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
@cached_property
def _audio_processor(self) -> AudioEncoder:
audio_encoder = self.tokenizer.instruct.audio_encoder
assert isinstance(audio_encoder, AudioEncoder)
return audio_encoder
@cached_property
def audio_token_id(self) -> int:
return self._audio_processor.special_ids.audio
@cached_property
def begin_audio_token_id(self) -> int:
return self._audio_processor.special_ids.begin_audio
@cached_property
def sampling_rate(self) -> int:
return self._audio_processor.audio_config.sampling_rate
@cached_property
def frame_rate(self) -> float:
return self._audio_processor.audio_config.frame_rate
def get_num_audio_tokens(
self,
audio_length: int,
) -> int:
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
audios: np.ndarray | list[np.ndarray] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if audios is None:
audios = []
if not isinstance(audios, list):
audios = [audios]
if not audios:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
# Allow dummy text, which is used for profiling as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
audios_tokens = list[torch.Tensor]()
audios_processed = list[torch.Tensor]()
for audio in audios:
assert isinstance(audio, np.ndarray)
assert audio.ndim == 1
if not self._audio_processor.audio_config.is_streaming:
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
] * self.get_num_audio_tokens(len(audio))
audios_tokens.append(torch.tensor(audio_tokens))
audios_processed.append(torch.tensor(audio))
return BatchFeature(
{
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
"audio_arrays": audios_processed,
}
)
class VoxtralProcessingInfo(BaseProcessingInfo): class VoxtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer: def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config) tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
...@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo): ...@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
return tokenizer return tokenizer
def get_hf_processor(self) -> VoxtralProcessorAdapter: def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
return VoxtralProcessorAdapter(self.get_tokenizer()) return self.ctx.init_processor(
MistralCommonVoxtralProcessor,
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_data_parser(self): def get_data_parser(self):
feature_extractor = self.get_hf_processor().feature_extractor
return MultiModalDataParser( return MultiModalDataParser(
target_sr=self.get_hf_processor().sampling_rate, target_sr=feature_extractor.sampling_rate,
target_channels=1, target_channels=1,
expected_hidden_size=self._get_expected_hidden_size(), expected_hidden_size=self._get_expected_hidden_size(),
) )
...@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo): ...@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
return self.ctx.model_config.max_model_len return self.ctx.model_config.max_model_len
def get_max_audio_array_len(self) -> int: def get_max_audio_array_len(self) -> int:
processor = self.get_hf_processor() feature_extractor = self.get_hf_processor().feature_extractor
return self.get_max_audio_tokens() * int( return self.get_max_audio_tokens() * int(
processor.sampling_rate // processor.frame_rate feature_extractor.sampling_rate // feature_extractor.frame_rate
) )
...@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
mm_options: Mapping[str, BaseDummyOptions], mm_options: Mapping[str, BaseDummyOptions],
) -> ProcessorInputs: ) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
feature_extractor = self.info.get_hf_processor().feature_extractor
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
...@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
for audio in dummy_audios: for audio in dummy_audios:
audio_item = Audio( audio_item = Audio(
audio_array=audio, audio_array=audio,
sampling_rate=self.info.get_hf_processor().sampling_rate, sampling_rate=feature_extractor.sampling_rate,
format=format, format=format,
) )
chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item)) chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
...@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# skip validation here # skip validation here
... ...
def _apply_hf_processor_mm_only( def _call_hf_processor(
self, self,
mm_items: MultiModalDataItems, prompt: str,
hf_processor_mm_kwargs: Mapping[str, object], mm_data: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) mm_data = dict(mm_data)
processor_data, passthrough_data = self._get_hf_mm_data(mm_items) audios = mm_data.pop("audios", [])
audios = processor_data.get("audios", [])
if not isinstance(audios, list): if audios:
audios = [audios] # MistralCommonVoxtralProcessor accepts "audio"
mm_data["audio"] = audios
audio_config = processor._audio_processor.audio_config
audio_tensors: list[torch.Tensor] = [] return super()._call_hf_processor(
for audio in audios: prompt=prompt,
audio = np.asarray(audio, dtype=np.float32).ravel() mm_data=mm_data,
if not audio_config.is_streaming: mm_kwargs=mm_kwargs,
audio = processor._audio_processor.pad( tok_kwargs=tok_kwargs,
audio, )
processor.sampling_rate,
audio_config.is_streaming,
)
audio_tensors.append(torch.tensor(audio))
result = BatchFeature({"audio_arrays": audio_tensors} if audio_tensors else {})
result.update(passthrough_data)
return result
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
feature_extractor = processor.feature_extractor
audio_id = processor.audio_token_id audio_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.require_data() out_mm_data = out_mm_kwargs.require_data()
...@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
audios = mm_items.get_items("audio", AudioProcessorItems) audios = mm_items.get_items("audio", AudioProcessorItems)
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
nb_audio_tokens = processor.get_num_audio_tokens(audio_len) nb_audio_tokens = feature_extractor.get_num_audio_tokens(audio_len)
return [audio_id] * nb_audio_tokens return [audio_id] * nb_audio_tokens
...@@ -560,8 +465,8 @@ class VoxtralForConditionalGeneration( ...@@ -560,8 +465,8 @@ class VoxtralForConditionalGeneration(
This is used for estimating the amount of processing for this audio. This is used for estimating the amount of processing for this audio.
""" """
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
adapter = VoxtralProcessorAdapter(tokenizer) adapter = MistralCommonVoxtralProcessor(tokenizer)
return adapter.get_num_audio_tokens( return adapter.feature_extractor.get_num_audio_tokens(
int(audio_duration_s * stt_config.sample_rate) int(audio_duration_s * stt_config.sample_rate)
) )
......
...@@ -8,12 +8,13 @@ from typing import Literal ...@@ -8,12 +8,13 @@ from typing import Literal
import numpy as np import numpy as np
import torch import torch
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import ( from mistral_common.protocol.transcription.request import (
StreamingMode, StreamingMode,
TranscriptionRequest, TranscriptionRequest,
) )
from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig from mistral_common.tokens.tokenizers.audio import AudioConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Mapping from collections.abc import Callable, Mapping
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cached_property from functools import cached_property
...@@ -241,13 +241,13 @@ class InputProcessingContext: ...@@ -241,13 +241,13 @@ class InputProcessingContext:
def call_hf_processor( def call_hf_processor(
self, self,
hf_processor: ProcessorMixin, hf_processor: Callable[..., BatchFeature] | ProcessorMixin,
data: Mapping[str, object], data: Mapping[str, object],
kwargs: Mapping[str, object] = {}, kwargs: Mapping[str, object] = {},
*, *,
num_tries: int = 1, num_tries: int = 1,
max_tries: int = 5, max_tries: int = 5,
) -> BatchFeature | JSONTree: ) -> BatchFeature:
""" """
Call `hf_processor` on the prompt `data` Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`. (text, image, audio...) with configurable options `kwargs`.
...@@ -300,7 +300,7 @@ class InputProcessingContext: ...@@ -300,7 +300,7 @@ class InputProcessingContext:
if isinstance(output, BatchFeature): if isinstance(output, BatchFeature):
output_ = self._postprocess_output(output.data) output_ = self._postprocess_output(output.data)
return BatchFeature(output_) return BatchFeature(output_) # type: ignore
logger.warning_once( logger.warning_once(
"%s did not return `BatchFeature`. " "%s did not return `BatchFeature`. "
...@@ -309,7 +309,7 @@ class InputProcessingContext: ...@@ -309,7 +309,7 @@ class InputProcessingContext:
type(hf_processor).__name__, type(hf_processor).__name__,
) )
return self._postprocess_output(output) return self._postprocess_output(output) # type: ignore
class BaseProcessingInfo: class BaseProcessingInfo:
......
...@@ -241,12 +241,13 @@ def get_processor_kwargs_type( ...@@ -241,12 +241,13 @@ def get_processor_kwargs_type(
call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None
# if the processor has explicit kwargs annotation, use it # if the processor has explicit kwargs annotation, use it
if call_kwargs_annotations not in (None, inspect._empty): if call_kwargs_annotations not in (None, inspect._empty): # noqa: SIM102
# get_type_hints will parse all type annotations at runtime, # get_type_hints will parse all type annotations at runtime,
# and if an annotation refers to a type or # and if an annotation refers to a type or
# name that hasn’t been imported or defined, it will raise an error. # name that hasn’t been imported or defined, it will raise an error.
# So we use __annotations__ to get the raw annotations directly. # So we use __annotations__ to get the raw annotations directly.
return get_args(call_kwargs_annotations)[0] if anno_args := get_args(call_kwargs_annotations):
return anno_args[0]
# otherwise, try to get from ProcessorKwargs # otherwise, try to get from ProcessorKwargs
module_name = type(processor).__module__ module_name = type(processor).__module__
...@@ -266,7 +267,13 @@ def get_processor_kwargs_keys( ...@@ -266,7 +267,13 @@ def get_processor_kwargs_keys(
kwargs_cls: type[processing_utils.ProcessingKwargs], kwargs_cls: type[processing_utils.ProcessingKwargs],
) -> set[str]: ) -> set[str]:
dynamic_kwargs: set[str] = set() dynamic_kwargs: set[str] = set()
modality_kwargs = {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"} modality_kwargs = {
"text_kwargs",
"images_kwargs",
"videos_kwargs",
"audio_kwargs",
"common_kwargs",
}
try: try:
# get kwargs annotations in processor # get kwargs annotations in processor
......
...@@ -15,10 +15,14 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -15,10 +15,14 @@ _CLASS_TO_MODULE: dict[str, str] = {
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2", "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2", "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
"FunASRProcessor": "vllm.transformers_utils.processors.funasr", "FunASRProcessor": "vllm.transformers_utils.processors.funasr",
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl", "HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image", "HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
"OvisProcessor": "vllm.transformers_utils.processors.ovis", "OvisProcessor": "vllm.transformers_utils.processors.ovis",
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5", "Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr", "Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
} }
...@@ -28,10 +32,14 @@ __all__ = [ ...@@ -28,10 +32,14 @@ __all__ = [
"DeepseekVLV2Processor", "DeepseekVLV2Processor",
"FireRedASR2Processor", "FireRedASR2Processor",
"FunASRProcessor", "FunASRProcessor",
"GLM4VProcessor",
"HunYuanVLProcessor", "HunYuanVLProcessor",
"HunYuanVLImageProcessor", "HunYuanVLImageProcessor",
"MistralCommonPixtralProcessor",
"MistralCommonVoxtralProcessor",
"OvisProcessor", "OvisProcessor",
"Ovis2_5Processor", "Ovis2_5Processor",
"QwenVLProcessor",
"Qwen3ASRProcessor", "Qwen3ASRProcessor",
] ]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import PreTrainedTokenizer
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import PILImageResampling
from transformers.processing_utils import ProcessorMixin
class GLM4VImageProcessorFast(BaseImageProcessorFast):
"""
Port of https://huggingface.co/zai-org/glm-4v-9b/blob/main/tokenization_chatglm.py#L177
to HF Transformers.
"""
resample = PILImageResampling.BICUBIC
image_mean = [0.48145466, 0.4578275, 0.40821073]
image_std = [0.26862954, 0.26130258, 0.27577711]
size = {"height": 1120, "width": 1120}
do_resize = True
do_rescale = True
do_normalize = True
class GLM4VProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
def __init__(
self,
tokenizer: PreTrainedTokenizer,
image_size: int,
) -> None:
self.tokenizer = tokenizer
self.image_processor = GLM4VImageProcessorFast(
size={"width": image_size, "height": image_size}
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from mistral_common.protocol.instruct.chunk import ImageChunk
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.audio_utils import AudioInput
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from vllm.tokenizers.mistral import MistralTokenizer
class MistralCommonImageProcessor:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def __init__(self, mm_encoder: ImageEncoder) -> None:
self.mm_encoder = mm_encoder
def __call__(
self,
images: ImageInput,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
images_lst = [images] if not isinstance(images, list) else images
images_processed = list[torch.Tensor]()
for image in images_lst:
image_inputs = self.mm_encoder(ImageChunk(image=image))
image_processed = torch.tensor(image_inputs.image)
images_processed.append(image_processed)
return BatchFeature({"images": images_processed}, tensor_type=return_tensors)
def get_number_of_image_patches(
self,
height: int,
width: int,
) -> tuple[int, int, int]:
image = Image.new("RGB", (width, height))
ncols, nrows = self.mm_encoder._image_to_num_tokens(image)
return ncols * nrows, nrows, ncols
class MistralCommonPixtralProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
def __init__(self, tokenizer: MistralTokenizer) -> None:
self.tokenizer = tokenizer.transformers_tokenizer
self.image_processor = MistralCommonImageProcessor(
tokenizer.instruct.mm_encoder
)
self._image_special_ids = self.image_processor.mm_encoder.special_ids
@property
def image_break_id(self) -> int:
return self._image_special_ids.img_break
@property
def image_token_id(self) -> int:
return self._image_special_ids.img
@property
def image_end_id(self) -> int:
return self._image_special_ids.img_end
def __call__(
self,
images: ImageInput | None = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
):
if images is None and text is None and videos is None and audio is None:
raise ValueError(
f"You need to provide at least one input to "
f"call {self.__class__.__name__}"
)
kwargs = self._merge_kwargs(
self.valid_processor_kwargs,
tokenizer_init_kwargs={},
**kwargs,
)
kwargs["text_kwargs"]["return_tensors"] = "pt"
kwargs["images_kwargs"]["return_tensors"] = None # Avoid padding issue
attribute_to_kwargs = {
"tokenizer": (text, "text_kwargs"),
"image_processor": (images, "images_kwargs"),
"video_processor": (videos, "videos_kwargs"),
"feature_extractor": (audio, "audio_kwargs"),
}
outputs = {}
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name, None)
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
if input_data is not None and attribute is not None:
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)
return BatchFeature(outputs)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import PILImageResampling
from transformers.processing_utils import ProcessorMixin
from vllm.tokenizers.qwen_vl import QwenVLTokenizer
class QwenVLImageProcessorFast(BaseImageProcessorFast):
"""
Port of https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
to HF Transformers.
"""
resample = PILImageResampling.BICUBIC
image_mean = [0.48145466, 0.4578275, 0.40821073]
image_std = [0.26862954, 0.26130258, 0.27577711]
size = {"height": 448, "width": 448}
do_resize = True
do_rescale = True
do_normalize = True
class QwenVLProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
def __init__(
self,
tokenizer: QwenVLTokenizer,
image_size: int,
) -> None:
self.tokenizer = tokenizer
self.image_processor = QwenVLImageProcessorFast(
size={"width": image_size, "height": image_size}
)
@property
def image_start_tag(self) -> str:
return self.tokenizer.image_start_tag # type: ignore[attr-defined]
@property
def image_end_tag(self) -> str:
return self.tokenizer.image_end_tag # type: ignore[attr-defined]
@property
def image_pad_tag(self) -> str:
return self.tokenizer.image_pad_tag # type: ignore[attr-defined]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import ceil
import numpy as np
import torch
from mistral_common.tokens.tokenizers.audio import AudioEncoder
from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.audio_utils import AudioInput
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from vllm.tokenizers.mistral import MistralTokenizer
class MistralCommonFeatureExtractor:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
"""
def __init__(self, audio_encoder: AudioEncoder) -> None:
self.audio_encoder = audio_encoder
@property
def sampling_rate(self):
return self.audio_encoder.audio_config.sampling_rate
@property
def frame_rate(self):
return self.audio_encoder.audio_config.frame_rate
def __call__(
self,
audios: AudioInput,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
audios_lst = [audios] if not isinstance(audios, list) else audios
audios_processed = list[torch.Tensor]()
for audio in audios_lst:
audio = np.asarray(audio, dtype=np.float32).ravel()
if not self.audio_encoder.audio_config.is_streaming:
audio = self.audio_encoder.pad(audio, self.sampling_rate)
audios_processed.append(torch.tensor(audio))
return BatchFeature(
{"audio_arrays": audios_processed}, tensor_type=return_tensors
)
def get_num_audio_tokens(self, audio_length: int) -> int:
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
class MistralCommonVoxtralProcessor(ProcessorMixin):
attributes = ["feature_extractor", "tokenizer"]
def __init__(self, tokenizer: MistralTokenizer) -> None:
self.tokenizer = tokenizer.transformers_tokenizer
self.feature_extractor = MistralCommonFeatureExtractor(
tokenizer.instruct.audio_encoder
)
self._audio_special_ids = self.feature_extractor.audio_encoder.special_ids
@property
def audio_token_id(self) -> int:
return self._audio_special_ids.audio
@property
def begin_audio_token_id(self) -> int:
return self._audio_special_ids.begin_audio
def __call__(
self,
images: ImageInput | None = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
):
if images is None and text is None and videos is None and audio is None:
raise ValueError(
f"You need to provide at least one input to "
f"call {self.__class__.__name__}"
)
kwargs = self._merge_kwargs(
self.valid_processor_kwargs,
tokenizer_init_kwargs={},
**kwargs,
)
kwargs["text_kwargs"]["return_tensors"] = "pt"
kwargs["audio_kwargs"]["return_tensors"] = None # Avoid padding issue
attribute_to_kwargs = {
"tokenizer": (text, "text_kwargs"),
"image_processor": (images, "images_kwargs"),
"video_processor": (videos, "videos_kwargs"),
"feature_extractor": (audio, "audio_kwargs"),
}
outputs = {}
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name, None)
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
if input_data is not None and attribute is not None:
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)
return BatchFeature(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment