Commit 982beae8 (unverified), authored by Netanel Haber, committed by GitHub
Browse files

Optimize nemotron VL image/video preprocessing (#40283)


Signed-off-by: milesial <milesial@users.noreply.github.com>
Co-authored-by: milesial <milesial@users.noreply.github.com>
parent 45232a45
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
# -------------------------------------------------------- # --------------------------------------------------------
import math import math
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
...@@ -26,7 +25,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType ...@@ -26,7 +25,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.model_executor.models.parakeet import ParakeetExtractor from vllm.model_executor.models.parakeet import ParakeetExtractor
from vllm.multimodal.evs import compute_retained_tokens_count from vllm.multimodal.evs import compute_retained_tokens_count
from vllm.multimodal.inputs import AudioItem from vllm.multimodal.inputs import AudioItem
from vllm.multimodal.processing.processor import PromptUpdateDetails, _seq2tokens from vllm.multimodal.processing.processor import PromptUpdateDetails
from vllm.tokenizers.hf import HfTokenizer from vllm.tokenizers.hf import HfTokenizer
from .internvl import calculate_internvl_targets, get_internvl_target_ratios from .internvl import calculate_internvl_targets, get_internvl_target_ratios
...@@ -63,42 +62,50 @@ def calculate_timestamps( ...@@ -63,42 +62,50 @@ def calculate_timestamps(
return timestamps return timestamps
def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor): @torch.compile(dynamic=True)
return (x - norm_mean) / norm_std def _bicubic_resize_and_normalize(
tensor: torch.Tensor,
size: tuple[int, int] | None = None,
norm_mean: torch.Tensor | None = None,
norm_std: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Permute NHWC→NCHW, optional bicubic resize, rescale + normalize.
Input must be a raw 4-D **NHWC** tensor.
def _bicubic_from_ndarray( *size*: target ``(H, W)``; skips interpolation when ``None``.
array: npt.NDArray[Any], *, size: tuple[int, int] *norm_mean* / *norm_std*: when both provided, fused
) -> torch.Tensor: ``(x/255 - mean) / std`` + dtype cast; otherwise ``x/255`` + cast.
"""
Convert a 4D NHWC ndarray to NCHW and interpolate with bicubic.
Suppresses PyTorch's non-writable NumPy warning because interpolate copies,
and torch.from_numpy(array) is discarded at the end of function scope.
""" """
tensor = tensor.permute(0, 3, 1, 2).to(dtype=torch.float32)
with warnings.catch_warnings(): if size is not None:
msg = "The given NumPy array is not writ.*" tensor = torch.nn.functional.interpolate(
# Apparently, different versions of PyTorch use writable or writeable.
warnings.filterwarnings("ignore", message=msg, category=UserWarning)
tensor = torch.from_numpy(array)
assert tensor.ndim == 4, f"{tensor.ndim=}"
tensor = tensor.permute(0, 3, 1, 2)
return (
torch.nn.functional.interpolate(
tensor, size=size, mode="bicubic", align_corners=False, antialias=True tensor, size=size, mode="bicubic", align_corners=False, antialias=True
) )
/ 255.0 if norm_mean is not None and norm_std is not None:
return ((tensor / 255.0 - norm_mean) / norm_std).to(dtype=dtype).contiguous()
return (tensor / 255.0).to(dtype=dtype).contiguous()
def _pil_to_nhwc_tensor(image: Image.Image) -> torch.Tensor:
"""Convert a PIL image to a 4-D NHWC tensor suitable for compiled ops."""
array = np.asarray(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
) )
return torch.from_numpy(np.expand_dims(array, axis=0))
def dynamic_preprocess( def dynamic_preprocess(
image, image: Image.Image,
*, *,
image_size=512, image_size: int = 512,
max_num_tiles=12, max_num_tiles: int = 12,
use_thumbnail=True, use_thumbnail: bool = True,
idx=0, norm_mean: torch.Tensor | None = None,
): norm_std: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
orig_width, orig_height = image.size orig_width, orig_height = image.size
target_ratios = get_internvl_target_ratios(1, max_num_tiles) target_ratios = get_internvl_target_ratios(1, max_num_tiles)
...@@ -111,13 +118,15 @@ def dynamic_preprocess( ...@@ -111,13 +118,15 @@ def dynamic_preprocess(
use_thumbnail=False, use_thumbnail=False,
) )
image = np.asarray( tensor = _pil_to_nhwc_tensor(image)
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
)
image = np.expand_dims(image, axis=0) resized_img = _bicubic_resize_and_normalize(
tensor,
resized_img = _bicubic_from_ndarray(image, size=(target_height, target_width)) size=(target_height, target_width),
norm_mean=norm_mean,
norm_std=norm_std,
dtype=dtype,
)
B, C, H, W = resized_img.shape B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size hp, wp = H // image_size, W // image_size
patches = ( patches = (
...@@ -127,30 +136,16 @@ def dynamic_preprocess( ...@@ -127,30 +136,16 @@ def dynamic_preprocess(
) )
if use_thumbnail and patches.shape[0] > 1: if use_thumbnail and patches.shape[0] > 1:
thumb = _bicubic_from_ndarray(image, size=(image_size, image_size)) thumb = _bicubic_resize_and_normalize(
tensor,
size=(image_size, image_size),
norm_mean=norm_mean,
norm_std=norm_std,
dtype=dtype,
)
patches = torch.cat([patches, thumb], dim=0) patches = torch.cat([patches, thumb], dim=0)
return list(patches) return patches
def image_to_pixel_values(
image: Image.Image,
*,
input_size: int,
max_num: int,
use_thumbnail: bool,
idx: int,
) -> torch.Tensor:
images = dynamic_preprocess(
image,
image_size=input_size,
max_num_tiles=max_num,
use_thumbnail=use_thumbnail,
idx=idx,
)
pixel_values = torch.stack(images)
return pixel_values
def _compute_aspect_preserving_size( def _compute_aspect_preserving_size(
...@@ -233,14 +228,16 @@ def video_to_pixel_values( ...@@ -233,14 +228,16 @@ def video_to_pixel_values(
video_maintain_aspect_ratio: bool = False, video_maintain_aspect_ratio: bool = False,
patch_size: int = 16, patch_size: int = 16,
downsample_ratio: float = 0.5, downsample_ratio: float = 0.5,
norm_mean: torch.Tensor | None = None,
norm_std: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor: ) -> torch.Tensor:
# (num_frames, H, W, C) -> (num_frames, C, H, W) """Convert video ndarray (T, H, W, C) to normalized pixel tensor (T, C, H, W)."""
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2) orig_h, orig_w = video.shape[1], video.shape[2]
size: tuple[int, int] | None = None
if video_target_num_patches is not None: if video_target_num_patches is not None:
# Resize to target patch count (aspect-preserving or square). tw, th, _ = get_video_target_size_and_feature_size(
orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
target_w, target_h, _ = get_video_target_size_and_feature_size(
orig_w=orig_w, orig_w=orig_w,
orig_h=orig_h, orig_h=orig_h,
target_patches=video_target_num_patches, target_patches=video_target_num_patches,
...@@ -248,14 +245,13 @@ def video_to_pixel_values( ...@@ -248,14 +245,13 @@ def video_to_pixel_values(
patch_size=patch_size, patch_size=patch_size,
downsample_ratio=downsample_ratio, downsample_ratio=downsample_ratio,
) )
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w: if orig_h != th or orig_w != tw:
return _bicubic_from_ndarray(video, size=(target_h, target_w)) size = (th, tw)
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size: elif orig_h != input_size or orig_w != input_size:
return _bicubic_from_ndarray(video, size=(input_size, input_size)) size = (input_size, input_size)
video_tensor = video_tensor / 255.0 tensor = torch.from_numpy(video)
return _bicubic_resize_and_normalize(tensor, size, norm_mean, norm_std, dtype)
return video_tensor
class DynamicResolutionImageTiler: class DynamicResolutionImageTiler:
...@@ -343,6 +339,7 @@ class DynamicResolutionImageTiler: ...@@ -343,6 +339,7 @@ class DynamicResolutionImageTiler:
self, self,
text_prompt_length: int, text_prompt_length: int,
images: list[Image.Image], images: list[Image.Image],
dtype: torch.dtype = torch.float32,
) -> tuple[list[torch.Tensor], list[int]]: ) -> tuple[list[torch.Tensor], list[int]]:
num_tokens_available = self.max_num_tokens_available(text_prompt_length) num_tokens_available = self.max_num_tokens_available(text_prompt_length)
params_per_image = self.compute_params(images, num_tokens_available) params_per_image = self.compute_params(images, num_tokens_available)
...@@ -350,7 +347,7 @@ class DynamicResolutionImageTiler: ...@@ -350,7 +347,7 @@ class DynamicResolutionImageTiler:
feature_sizes = [] feature_sizes = []
images = [] images = []
for param in params_per_image: for param in params_per_image:
for t in self.apply_params(param): for t in self.apply_params(param, dtype=dtype):
assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor" assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
images.append(t) images.append(t)
feature_sizes.append(param.num_embeddings) feature_sizes.append(param.num_embeddings)
...@@ -363,17 +360,23 @@ class DynamicResolutionImageTiler: ...@@ -363,17 +360,23 @@ class DynamicResolutionImageTiler:
num_embeddings: int num_embeddings: int
patch_size: tuple[int, int] patch_size: tuple[int, int]
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]: def apply_params(
self,
params: DynamicResolutionParams,
dtype: torch.dtype = torch.float32,
) -> list[torch.Tensor]:
target_size = ( target_size = (
params.patch_size[1] * self._patch_size, params.patch_size[1] * self._patch_size,
params.patch_size[0] * self._patch_size, params.patch_size[0] * self._patch_size,
) )
image = np.asarray( tensor = _pil_to_nhwc_tensor(params.media)
params.media.convert("RGB") if params.media.mode != "RGB" else params.media, resized_img = _bicubic_resize_and_normalize(
dtype=np.uint8, tensor,
size=target_size,
norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=dtype,
) )
image = np.expand_dims(image, axis=0)
resized_img = _bicubic_from_ndarray(image, size=target_size)
return list(resized_img) return list(resized_img)
def process_media( def process_media(
...@@ -619,6 +622,7 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -619,6 +622,7 @@ class BaseNanoNemotronVLProcessor(ABC):
norm_mean=config.norm_mean, norm_mean=config.norm_mean,
norm_std=config.norm_std, norm_std=config.norm_std,
) )
self.dtype: torch.dtype = getattr(config, "dtype", torch.float32)
@staticmethod @staticmethod
def use_dynamic_resolution(config: PretrainedConfig) -> bool: def use_dynamic_resolution(config: PretrainedConfig) -> bool:
...@@ -662,14 +666,16 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -662,14 +666,16 @@ class BaseNanoNemotronVLProcessor(ABC):
max_num_tiles: int, max_num_tiles: int,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
return [ return [
image_to_pixel_values( dynamic_preprocess(
image, image,
input_size=self.image_size, image_size=self.image_size,
max_num=max_num_tiles, max_num_tiles=max_num_tiles,
use_thumbnail=self.use_thumbnail, use_thumbnail=self.use_thumbnail,
idx=idx, norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=self.dtype,
) )
for idx, image in enumerate(images) for image in images
] ]
def _preprocess_image( def _preprocess_image(
...@@ -690,23 +696,22 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -690,23 +696,22 @@ class BaseNanoNemotronVLProcessor(ABC):
pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst( pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
text_prompt_length=text_prompt_length, text_prompt_length=text_prompt_length,
images=images, images=images,
dtype=self.dtype,
) )
imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst] imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
normalized = [
input_conditioner(img, tiler.norm_mean, tiler.norm_std)
for img in pixel_values_lst
]
image_num_patches = torch.tensor([1] * len(num_tokens_per_image)) image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
image_inputs = { image_inputs = {
"pixel_values_flat": normalized, "pixel_values_flat": pixel_values_lst,
"imgs_sizes": imgs_sizes, "imgs_sizes": imgs_sizes,
"num_tokens_per_image": num_tokens_per_image, "num_tokens_per_image": num_tokens_per_image,
} }
else: else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
image_num_patches = torch.tensor([len(item) for item in pixel_values_lst]) image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
pixel_values_flat = input_conditioner( pixel_values_flat = (
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std torch.cat(pixel_values_lst)
if len(pixel_values_lst) > 1
else pixel_values_lst[0]
) )
image_inputs = { image_inputs = {
"pixel_values_flat": pixel_values_flat, "pixel_values_flat": pixel_values_flat,
...@@ -863,6 +868,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -863,6 +868,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def _videos_to_pixel_values_lst( def _videos_to_pixel_values_lst(
self, self,
videos: list[npt.NDArray], videos: list[npt.NDArray],
*,
dtype: torch.dtype = torch.float32,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
return [ return [
video_to_pixel_values( video_to_pixel_values(
...@@ -872,6 +879,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -872,6 +879,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_maintain_aspect_ratio=self.video_maintain_aspect_ratio, video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
patch_size=self.config.patch_size, patch_size=self.config.patch_size,
downsample_ratio=self.config.downsample_ratio, downsample_ratio=self.config.downsample_ratio,
norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=dtype,
) )
for video in videos for video in videos
] ]
...@@ -886,8 +896,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -886,8 +896,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
videos_lst = [v[0] for v in videos] videos_lst = [v[0] for v in videos]
video_metadata_lst = [v[1] for v in videos] video_metadata_lst = [v[1] for v in videos]
pixel_values_lst_video = self._videos_to_pixel_values_lst( pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst, videos_lst,
dtype=self.dtype,
) )
# We use frame duration in milliseconds (as integer) to ensure # We use frame duration in milliseconds (as integer) to ensure
...@@ -903,10 +915,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -903,10 +915,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
metadata["frames_indices"] for metadata in video_metadata_lst metadata["frames_indices"] for metadata in video_metadata_lst
] ]
video_num_patches = torch.tensor([len(item) for item in pixel_values_lst_video]) video_num_patches = torch.tensor([len(item) for item in pixel_values_lst_video])
# Normalization already fused into resize above.
# Skip the torch.cat copy when there is exactly one video
if len(pixel_values_lst_video) == 1:
pixel_values_flat = pixel_values_lst_video[0]
else:
pixel_values_flat = torch.cat(pixel_values_lst_video)
video_inputs = { video_inputs = {
"pixel_values_flat_video": input_conditioner( "pixel_values_flat_video": pixel_values_flat,
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
),
"video_num_patches": video_num_patches, "video_num_patches": video_num_patches,
"frames_indices": frames_indices_lst, "frames_indices": frames_indices_lst,
"frame_duration_ms": torch.tensor(frame_duration_ms_lst), "frame_duration_ms": torch.tensor(frame_duration_ms_lst),
...@@ -1168,20 +1185,21 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -1168,20 +1185,21 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
for i, _ in enumerate(tokens_per_frame) for i, _ in enumerate(tokens_per_frame)
] ]
# Tokenize frame separator independently # Batch-tokenize all frame separators at once — the HuggingFace
frame_separators_tokenized = [ # tokenizers Rust backend parallelizes batch encoding across threads.
_seq2tokens(tokenizer, sep) for sep in frame_separators batch_encoded = tokenizer(
] frame_separators,
add_special_tokens=False,
return_attention_mask=False,
)
frame_separators_tokenized: list[list[int]] = batch_encoded["input_ids"]
# Tokenize each component independently to avoid tokenizer merging tokens # Tokenize each component independently to avoid tokenizer merging tokens
# across boundaries. This ensures consistent tokenization regardless of # across boundaries. This ensures consistent tokenization regardless of
# num_tokens_per_frame values. # num_tokens_per_frame values.
all_token_ids = [] all_token_ids = []
for i, num_tokens in enumerate(tokens_per_frame): for i, num_tokens in enumerate(tokens_per_frame):
frame_sep_token_ids = frame_separators_tokenized[i] all_token_ids.extend(frame_separators_tokenized[i])
all_token_ids.extend(frame_sep_token_ids)
# Add pre-tokenized special tokens
all_token_ids.extend(img_start_token_ids) all_token_ids.extend(img_start_token_ids)
all_token_ids.extend(img_context_token_ids * num_tokens) all_token_ids.extend(img_context_token_ids * num_tokens)
all_token_ids.extend(img_end_token_ids) all_token_ids.extend(img_end_token_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment