Unverified commit 30108fc8, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Refactor Step3-VL processor to HF style (#37579)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent e2d1c8b5
...@@ -39,7 +39,11 @@ from vllm.multimodal.processing import ( ...@@ -39,7 +39,11 @@ from vllm.multimodal.processing import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig
from vllm.transformers_utils.processors.step3_vl import Step3VLProcessor from vllm.transformers_utils.processors.step3_vl import (
MAX_IMAGE_SIZE,
Step3VLImageProcessor,
Step3VLProcessor,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -86,21 +90,30 @@ Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingI ...@@ -86,21 +90,30 @@ Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingI
class Step3VLProcessingInfo(BaseProcessingInfo): class Step3VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs):
    """Create the Step3-VL image processor for this model.

    Any keyword arguments are forwarded to ``Step3VLImageProcessor``. When the
    caller does not pass ``enable_patch``, it is taken from the HF config's
    ``vision_config.enable_patch`` attribute, falling back to ``True`` if that
    attribute is absent.
    """
    vision_config = self.get_hf_config().vision_config
    default_enable_patch = getattr(vision_config, "enable_patch", True)

    # An explicit caller-supplied value always wins over the config default.
    if "enable_patch" not in kwargs:
        kwargs["enable_patch"] = default_enable_patch

    return Step3VLImageProcessor(**kwargs)
def get_hf_processor(self) -> Step3VLProcessor: def get_hf_processor(self) -> Step3VLProcessor:
return Step3VLProcessor( return Step3VLProcessor(
self.get_hf_config(), tokenizer=self.get_tokenizer(),
self.get_tokenizer(), image_processor=self.get_image_processor(),
) )
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor() image_processor = self.get_image_processor()
return hf_processor.get_num_image_tokens( target_width, target_height = self.get_image_size_with_most_features()
self.get_image_size_with_most_features().width,
self.get_image_size_with_most_features().height, return image_processor.get_num_image_tokens(target_width, target_height)
)
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
...@@ -110,20 +123,7 @@ class Step3VLProcessingInfo(BaseProcessingInfo): ...@@ -110,20 +123,7 @@ class Step3VLProcessingInfo(BaseProcessingInfo):
return {"image": self.get_max_image_tokens()} return {"image": self.get_max_image_tokens()}
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
return ImageSize(3024, 3024) return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)
def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
    """Return the total number of image tokens required by ``mm_data``.

    ``mm_data`` must contain exactly one key, ``"image"``, whose value is
    either a single image or a list/tuple of images. Each image must expose
    ``width`` and ``height`` attributes; these are passed to the HF
    processor's ``get_num_image_tokens`` and the results are summed.

    Raises:
        ValueError: If ``mm_data`` does not consist of exactly the single
            key ``"image"``.
    """
    if len(mm_data) != 1 or "image" not in mm_data:
        raise ValueError("mm_data could only contain one key 'image' for step1o")

    image_data = mm_data["image"]
    # Normalize a lone image into a one-element list.
    if not isinstance(image_data, (list, tuple)):
        image_data = [image_data]

    # Fetch the processor once instead of once per image.
    hf_processor = self.get_hf_processor()
    return sum(
        hf_processor.get_num_image_tokens(img.width, img.height)
        for img in image_data
    )
class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
...@@ -165,13 +165,11 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]) ...@@ -165,13 +165,11 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo])
def get_replacement_step1o(item_idx: int): def get_replacement_step1o(item_idx: int):
out_item = out_mm_kwargs["image"][item_idx] out_item = out_mm_kwargs["image"][item_idx]
num_patches = int(out_item["num_patches"].data) num_patches = int(out_item["num_patches"].data)
if num_patches > 0: patch_newline_mask = out_item["patch_newline_mask"].data
patch_newline_mask = out_item["patch_newline_mask"].data image_repl_ids = hf_processor.get_image_repl_feature_ids(
image_repl_ids = hf_processor._get_image_repl_features( 1, num_patches, patch_newline_mask.tolist()
1, num_patches, patch_newline_mask.tolist() )
)[1]
else:
image_repl_ids = hf_processor._get_image_repl_features(1, 0, None)[1]
return PromptUpdateDetails.select_token_id( return PromptUpdateDetails.select_token_id(
seq=image_repl_ids, seq=image_repl_ids,
embed_token_id=image_placeholder_token_id, embed_token_id=image_placeholder_token_id,
......
...@@ -558,6 +558,7 @@ class InternVLProcessor(ProcessorMixin): ...@@ -558,6 +558,7 @@ class InternVLProcessor(ProcessorMixin):
else: else:
text_inputs = {} text_inputs = {}
combined_outputs = {**text_inputs, **image_inputs, **video_inputs} return BatchFeature(
data={**text_inputs, **image_inputs, **video_inputs},
return BatchFeature(combined_outputs, tensor_type=return_tensors) tensor_type=return_tensors,
)
...@@ -19,7 +19,6 @@ class KimiK25Processor(ProcessorMixin): ...@@ -19,7 +19,6 @@ class KimiK25Processor(ProcessorMixin):
self.media_token_id = media_token_id self.media_token_id = media_token_id
assert self.media_token_id is not None assert self.media_token_id is not None
# We do not support str input for text here
def __call__( def __call__(
self, self,
vision_chunks: list[VisionChunk] | None = None, vision_chunks: list[VisionChunk] | None = None,
......
...@@ -8,13 +8,13 @@ import torch ...@@ -8,13 +8,13 @@ import torch
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchFeature, ProcessorMixin, TensorType
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
MAX_IMAGE_SIZE: int = 3024 MAX_IMAGE_SIZE: int = 3024
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None] ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool]]
class Step3VisionProcessor: class Step3VisionProcessor:
...@@ -185,7 +185,7 @@ class ImagePatcher: ...@@ -185,7 +185,7 @@ class ImagePatcher:
def __call__( def __call__(
self, img: Image.Image self, img: Image.Image
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]: ) -> tuple[Image.Image, list[Image.Image], list[bool]]:
img_width, img_height = img.size img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_padding( new_img_width, new_img_height = self.get_image_size_for_padding(
img_width, img_height img_width, img_height
...@@ -203,7 +203,7 @@ class ImagePatcher: ...@@ -203,7 +203,7 @@ class ImagePatcher:
) )
if window_size == 0 or not self.enable_patch: if window_size == 0 or not self.enable_patch:
return img, [], None return img, [], []
else: else:
new_img_width, new_img_height = self.get_image_size_for_crop( new_img_width, new_img_height = self.get_image_size_for_crop(
new_img_width, new_img_height, window_size new_img_width, new_img_height, window_size
...@@ -236,43 +236,28 @@ class ImagePatcher: ...@@ -236,43 +236,28 @@ class ImagePatcher:
return ( return (
img, img,
patches, patches,
[i in newlines for i in range(len(patches))] [i in newlines for i in range(len(patches))],
if len(patches) > 0
else None,
) )
class Step3VLProcessor: class Step3VLImageProcessor:
def __init__( def __init__(
self, self,
config: PretrainedConfig, image_size: int = 728,
tokenizer: TokenizerLike, patch_size: int = 504,
num_image_feature_size: int = 169,
num_patch_feature_size: int = 81,
enable_patch: bool = True,
) -> None: ) -> None:
super().__init__() self.image_size = image_size
self.patch_size = patch_size
self.config = config self.num_image_feature_size = num_image_feature_size
self.tokenizer = tokenizer self.num_patch_feature_size = num_patch_feature_size
self.image_size = 728
self.patch_size = 504
self.image_preprocessor = Step3VisionProcessor( self.image_preprocessor = Step3VisionProcessor(
self.image_size, "bilinear", self.patch_size image_size, "bilinear", patch_size
) )
self.num_image_feature_size = 169
self.num_patch_feature_size = 81
self.image_token = "<im_patch>"
self.image_feature_placeholder = self.image_token * self.num_image_feature_size
self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
# Respect vision config switch to enable/disable patch extraction.
# For video understanding, it's preferable to disable patch.
enable_patch = getattr(self.config.vision_config, "enable_patch", True)
self.patcher = ImagePatcher(enable_patch=enable_patch) self.patcher = ImagePatcher(enable_patch=enable_patch)
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[self.image_token]
def get_num_image_tokens(self, img_width: int, img_height: int) -> int: def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height) num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
...@@ -299,58 +284,168 @@ class Step3VLProcessor: ...@@ -299,58 +284,168 @@ class Step3VLProcessor:
for img in images for img in images
] ]
def _get_patch_repl( def __call__(
self,
images: Image.Image | list[Image.Image] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if images is None:
images = []
if not isinstance(images, list):
images = [images]
split_images_data = self._split_images(images)
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
num_patches = []
for raw_img, img_patches, patch_newline_mask in split_images_data:
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
num_patches.append(len(img_patches))
patch_pixel_values_lst.extend(
self._convert_images_to_pixel_values(img_patches, is_patch=True)
)
patch_newline_mask_lst.extend(patch_newline_mask)
pixel_values = torch.cat(pixel_values_lst)
patch_size = self.patch_size
image_inputs = {
"pixel_values": pixel_values,
"num_patches": num_patches,
"patch_pixel_values": (
torch.cat(patch_pixel_values_lst)
if patch_pixel_values_lst
else pixel_values.new_empty((0, 3, patch_size, patch_size))
),
"patch_newline_mask": torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
),
}
return BatchFeature(image_inputs, tensor_type=return_tensors)
class Step3VLProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
def __init__(
self,
image_processor: Step3VLImageProcessor,
tokenizer: TokenizerLike,
) -> None:
self.image_processor = image_processor
self.tokenizer = tokenizer
self.image_start_token = image_start_token = "<im_start>"
self.image_end_token = image_end_token = "<im_end>"
self.patch_start_token = patch_start_token = "<patch_start>"
self.patch_end_token = patch_end_token = "<patch_end>"
self.patch_newline_token = patch_newline_token = "<patch_newline>"
self.image_start_token_id = tokenizer.convert_tokens_to_ids(image_start_token)
self.image_end_token_id = tokenizer.convert_tokens_to_ids(image_end_token)
self.patch_start_token_id = tokenizer.convert_tokens_to_ids(patch_start_token)
self.patch_end_token_id = tokenizer.convert_tokens_to_ids(patch_end_token)
self.patch_newline_token_id = tokenizer.convert_tokens_to_ids(
patch_newline_token
)
self.image_token = image_token = "<im_patch>"
self.image_feature_tokens = image_token * image_processor.num_image_feature_size
self.patch_feature_tokens = image_token * image_processor.num_patch_feature_size
self.image_token_id = image_token_id = tokenizer.convert_tokens_to_ids(
image_token
)
self.image_feature_token_ids = [
image_token_id
] * image_processor.num_image_feature_size
self.patch_feature_token_ids = [
image_token_id
] * image_processor.num_patch_feature_size
def _get_patch_repl_text(
self, self,
num_patches: int, num_patches: int,
patch_newline_mask: list[bool] | None, patch_newline_mask: list[bool],
) -> tuple[str, list[int]]: ) -> str:
text = "" assert len(patch_newline_mask) == num_patches
token_ids = []
parts = []
for i in range(num_patches): for i in range(num_patches):
assert ( parts.extend(
patch_newline_mask is not None [
and len(patch_newline_mask) == num_patches self.patch_start_token,
self.patch_feature_tokens,
self.patch_end_token,
]
) )
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>" if patch_newline_mask[i]:
token_ids.extend( parts.append(self.patch_newline_token)
[self.tokenizer.convert_tokens_to_ids("<patch_start>")]
+ [self.image_token_id] * self.num_patch_feature_size return "".join(parts)
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
def _get_patch_repl_ids(
self,
num_patches: int,
patch_newline_mask: list[bool],
) -> list[int]:
assert len(patch_newline_mask) == num_patches
parts = []
for i in range(num_patches):
parts.extend(
[
self.patch_start_token_id,
*self.patch_feature_token_ids,
self.patch_end_token_id,
]
) )
if patch_newline_mask and patch_newline_mask[i]: if patch_newline_mask[i]:
text += "<patch_newline>" parts.append(self.patch_newline_token_id)
token_ids.append(
self.tokenizer.convert_tokens_to_ids("<patch_newline>")
)
return text, token_ids
def _get_image_repl( return parts
def _get_image_repl_text(
self, self,
num_images: int, num_images: int,
) -> tuple[str, list[int]]: ) -> str:
text = f"<im_start>{self.image_feature_placeholder}<im_end>" parts = [
token_ids = ( self.image_start_token,
[self.tokenizer.convert_tokens_to_ids("<im_start>")] self.image_feature_tokens,
+ [self.image_token_id] * self.num_image_feature_size self.image_end_token,
+ [self.tokenizer.convert_tokens_to_ids("<im_end>")] ] * num_images
)
return text * num_images, token_ids * num_images return "".join(parts)
def _get_image_repl_ids(
self,
num_images: int,
) -> list[int]:
part = [
self.image_start_token_id,
*self.image_feature_token_ids,
self.image_end_token_id,
]
return part * num_images
def _get_image_repl_features( def get_image_repl_feature_text(
self, self,
num_images: int, num_images: int,
num_patches: int, num_patches: int,
patch_new_line_idx: list[bool] | None, patch_new_line_idx: list[bool],
) -> tuple[str, list[int]]: ) -> str:
if num_patches > 0: patch_repl = self._get_patch_repl_text(num_patches, patch_new_line_idx)
patch_repl, patch_repl_ids = self._get_patch_repl( image_repl = self._get_image_repl_text(num_images)
num_patches, patch_new_line_idx return patch_repl + image_repl
)
else: def get_image_repl_feature_ids(
patch_repl = "" self,
patch_repl_ids = [] num_images: int,
image_repl, image_repl_ids = self._get_image_repl(num_images) num_patches: int,
return patch_repl + image_repl, patch_repl_ids + image_repl_ids patch_new_line_idx: list[bool],
) -> list[int]:
patch_repl = self._get_patch_repl_ids(num_patches, patch_new_line_idx)
image_repl = self._get_image_repl_ids(num_images)
return patch_repl + image_repl
def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str: def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
parts = text.split(placeholder) parts = text.split(placeholder)
...@@ -373,69 +468,44 @@ class Step3VLProcessor: ...@@ -373,69 +468,44 @@ class Step3VLProcessor:
images: Image.Image | list[Image.Image] | None = None, images: Image.Image | list[Image.Image] | None = None,
return_tensors: str | TensorType | None = None, return_tensors: str | TensorType | None = None,
) -> BatchFeature: ) -> BatchFeature:
if text is None: if images is not None:
text = [] image_inputs = self.image_processor(
if not isinstance(text, list): images=images,
text = [text] return_tensors=return_tensors,
if images is None: )
images = [] num_patches = image_inputs["num_patches"]
if not isinstance(images, list): patch_newline_mask = image_inputs["patch_newline_mask"]
images = [images]
if len(images) == 0:
image_inputs = {}
text_inputs = self.tokenizer(text)
else: else:
split_images_data = self._split_images(images) image_inputs = {}
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
image_repl_str_lst = []
image_repl_ids_lst = []
num_patches = [] num_patches = []
for raw_img, img_patches, patch_newline_mask in split_images_data: patch_newline_mask = []
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
if text is not None:
if len(img_patches) > 0: if not isinstance(text, list):
patch_pixel_values_lst.extend( text = [text]
self._convert_images_to_pixel_values(img_patches, is_patch=True)
if image_inputs:
image_token = self.image_token
image_repl_str_lst = []
start = 0
for n_patches in num_patches:
image_repl_str = self.get_image_repl_feature_text(
1, n_patches, patch_newline_mask[start : start + n_patches]
) )
num_patches.append(len(img_patches)) image_repl_str_lst.append(image_repl_str)
image_repl_str, image_repl_ids = self._get_image_repl_features( start += n_patches
1, len(img_patches), patch_newline_mask
) text = [
image_repl_str_lst.append(image_repl_str) self.replace_placeholder(t, image_token, image_repl_str_lst)
image_repl_ids_lst.extend(image_repl_ids) for t in text
]
if patch_newline_mask is not None:
patch_newline_mask_lst.extend(patch_newline_mask)
pixel_values = torch.cat(pixel_values_lst)
patch_size = self.patch_size
image_inputs = {
"pixel_values": pixel_values,
"num_patches": num_patches,
"patch_pixel_values": (
torch.cat(patch_pixel_values_lst)
if patch_pixel_values_lst
else pixel_values.new_empty((0, 3, patch_size, patch_size))
),
"patch_newline_mask": torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
),
}
text = [
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
for t in text
]
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
else:
text_inputs = {}
return BatchFeature( return BatchFeature(
{ data={**text_inputs, **image_inputs},
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors, tensor_type=return_tensors,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment