Unverified Commit 982beae8 authored by Netanel Haber's avatar Netanel Haber Committed by GitHub
Browse files

Optimize nemotron VL image/video preprocessing (#40283)


Signed-off-by: default avatarmilesial <milesial@users.noreply.github.com>
Co-authored-by: default avatarmilesial <milesial@users.noreply.github.com>
parent 45232a45
......@@ -8,7 +8,6 @@
# --------------------------------------------------------
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
......@@ -26,7 +25,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.model_executor.models.parakeet import ParakeetExtractor
from vllm.multimodal.evs import compute_retained_tokens_count
from vllm.multimodal.inputs import AudioItem
from vllm.multimodal.processing.processor import PromptUpdateDetails, _seq2tokens
from vllm.multimodal.processing.processor import PromptUpdateDetails
from vllm.tokenizers.hf import HfTokenizer
from .internvl import calculate_internvl_targets, get_internvl_target_ratios
......@@ -63,42 +62,50 @@ def calculate_timestamps(
return timestamps
def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor):
return (x - norm_mean) / norm_std
@torch.compile(dynamic=True)
def _bicubic_resize_and_normalize(
tensor: torch.Tensor,
size: tuple[int, int] | None = None,
norm_mean: torch.Tensor | None = None,
norm_std: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Permute NHWC→NCHW, optional bicubic resize, rescale + normalize.
Input must be a raw 4-D **NHWC** tensor.
def _bicubic_from_ndarray(
array: npt.NDArray[Any], *, size: tuple[int, int]
) -> torch.Tensor:
"""
Convert a 4D NHWC ndarray to NCHW and interpolate with bicubic.
Suppresses PyTorch's non-writable NumPy warning because interpolate copies,
and torch.from_numpy(array) is discarded at the end of function scope.
*size*: target ``(H, W)``; skips interpolation when ``None``.
*norm_mean* / *norm_std*: when both provided, fused
``(x/255 - mean) / std`` + dtype cast; otherwise ``x/255`` + cast.
"""
with warnings.catch_warnings():
msg = "The given NumPy array is not writ.*"
# Apparently, different versions of PyTorch use writable or writeable.
warnings.filterwarnings("ignore", message=msg, category=UserWarning)
tensor = torch.from_numpy(array)
assert tensor.ndim == 4, f"{tensor.ndim=}"
tensor = tensor.permute(0, 3, 1, 2)
return (
torch.nn.functional.interpolate(
tensor = tensor.permute(0, 3, 1, 2).to(dtype=torch.float32)
if size is not None:
tensor = torch.nn.functional.interpolate(
tensor, size=size, mode="bicubic", align_corners=False, antialias=True
)
/ 255.0
if norm_mean is not None and norm_std is not None:
return ((tensor / 255.0 - norm_mean) / norm_std).to(dtype=dtype).contiguous()
return (tensor / 255.0).to(dtype=dtype).contiguous()
def _pil_to_nhwc_tensor(image: Image.Image) -> torch.Tensor:
"""Convert a PIL image to a 4-D NHWC tensor suitable for compiled ops."""
array = np.asarray(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
)
return torch.from_numpy(np.expand_dims(array, axis=0))
def dynamic_preprocess(
image,
image: Image.Image,
*,
image_size=512,
max_num_tiles=12,
use_thumbnail=True,
idx=0,
):
image_size: int = 512,
max_num_tiles: int = 12,
use_thumbnail: bool = True,
norm_mean: torch.Tensor | None = None,
norm_std: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
orig_width, orig_height = image.size
target_ratios = get_internvl_target_ratios(1, max_num_tiles)
......@@ -111,13 +118,15 @@ def dynamic_preprocess(
use_thumbnail=False,
)
image = np.asarray(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
)
tensor = _pil_to_nhwc_tensor(image)
image = np.expand_dims(image, axis=0)
resized_img = _bicubic_from_ndarray(image, size=(target_height, target_width))
resized_img = _bicubic_resize_and_normalize(
tensor,
size=(target_height, target_width),
norm_mean=norm_mean,
norm_std=norm_std,
dtype=dtype,
)
B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size
patches = (
......@@ -127,30 +136,16 @@ def dynamic_preprocess(
)
if use_thumbnail and patches.shape[0] > 1:
thumb = _bicubic_from_ndarray(image, size=(image_size, image_size))
patches = torch.cat([patches, thumb], dim=0)
return list(patches)
def image_to_pixel_values(
image: Image.Image,
*,
input_size: int,
max_num: int,
use_thumbnail: bool,
idx: int,
) -> torch.Tensor:
images = dynamic_preprocess(
image,
image_size=input_size,
max_num_tiles=max_num,
use_thumbnail=use_thumbnail,
idx=idx,
thumb = _bicubic_resize_and_normalize(
tensor,
size=(image_size, image_size),
norm_mean=norm_mean,
norm_std=norm_std,
dtype=dtype,
)
patches = torch.cat([patches, thumb], dim=0)
pixel_values = torch.stack(images)
return pixel_values
return patches
def _compute_aspect_preserving_size(
......@@ -233,14 +228,16 @@ def video_to_pixel_values(
video_maintain_aspect_ratio: bool = False,
patch_size: int = 16,
downsample_ratio: float = 0.5,
norm_mean: torch.Tensor | None = None,
norm_std: torch.Tensor | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
# (num_frames, H, W, C) -> (num_frames, C, H, W)
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
"""Convert video ndarray (T, H, W, C) to normalized pixel tensor (T, C, H, W)."""
orig_h, orig_w = video.shape[1], video.shape[2]
size: tuple[int, int] | None = None
if video_target_num_patches is not None:
# Resize to target patch count (aspect-preserving or square).
orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
target_w, target_h, _ = get_video_target_size_and_feature_size(
tw, th, _ = get_video_target_size_and_feature_size(
orig_w=orig_w,
orig_h=orig_h,
target_patches=video_target_num_patches,
......@@ -248,14 +245,13 @@ def video_to_pixel_values(
patch_size=patch_size,
downsample_ratio=downsample_ratio,
)
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
return _bicubic_from_ndarray(video, size=(target_h, target_w))
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
return _bicubic_from_ndarray(video, size=(input_size, input_size))
if orig_h != th or orig_w != tw:
size = (th, tw)
elif orig_h != input_size or orig_w != input_size:
size = (input_size, input_size)
video_tensor = video_tensor / 255.0
return video_tensor
tensor = torch.from_numpy(video)
return _bicubic_resize_and_normalize(tensor, size, norm_mean, norm_std, dtype)
class DynamicResolutionImageTiler:
......@@ -343,6 +339,7 @@ class DynamicResolutionImageTiler:
self,
text_prompt_length: int,
images: list[Image.Image],
dtype: torch.dtype = torch.float32,
) -> tuple[list[torch.Tensor], list[int]]:
num_tokens_available = self.max_num_tokens_available(text_prompt_length)
params_per_image = self.compute_params(images, num_tokens_available)
......@@ -350,7 +347,7 @@ class DynamicResolutionImageTiler:
feature_sizes = []
images = []
for param in params_per_image:
for t in self.apply_params(param):
for t in self.apply_params(param, dtype=dtype):
assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
images.append(t)
feature_sizes.append(param.num_embeddings)
......@@ -363,17 +360,23 @@ class DynamicResolutionImageTiler:
num_embeddings: int
patch_size: tuple[int, int]
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
def apply_params(
self,
params: DynamicResolutionParams,
dtype: torch.dtype = torch.float32,
) -> list[torch.Tensor]:
target_size = (
params.patch_size[1] * self._patch_size,
params.patch_size[0] * self._patch_size,
)
image = np.asarray(
params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
dtype=np.uint8,
tensor = _pil_to_nhwc_tensor(params.media)
resized_img = _bicubic_resize_and_normalize(
tensor,
size=target_size,
norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=dtype,
)
image = np.expand_dims(image, axis=0)
resized_img = _bicubic_from_ndarray(image, size=target_size)
return list(resized_img)
def process_media(
......@@ -619,6 +622,7 @@ class BaseNanoNemotronVLProcessor(ABC):
norm_mean=config.norm_mean,
norm_std=config.norm_std,
)
self.dtype: torch.dtype = getattr(config, "dtype", torch.float32)
@staticmethod
def use_dynamic_resolution(config: PretrainedConfig) -> bool:
......@@ -662,14 +666,16 @@ class BaseNanoNemotronVLProcessor(ABC):
max_num_tiles: int,
) -> list[torch.Tensor]:
return [
image_to_pixel_values(
dynamic_preprocess(
image,
input_size=self.image_size,
max_num=max_num_tiles,
image_size=self.image_size,
max_num_tiles=max_num_tiles,
use_thumbnail=self.use_thumbnail,
idx=idx,
norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=self.dtype,
)
for idx, image in enumerate(images)
for image in images
]
def _preprocess_image(
......@@ -690,23 +696,22 @@ class BaseNanoNemotronVLProcessor(ABC):
pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
text_prompt_length=text_prompt_length,
images=images,
dtype=self.dtype,
)
imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
normalized = [
input_conditioner(img, tiler.norm_mean, tiler.norm_std)
for img in pixel_values_lst
]
image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
image_inputs = {
"pixel_values_flat": normalized,
"pixel_values_flat": pixel_values_lst,
"imgs_sizes": imgs_sizes,
"num_tokens_per_image": num_tokens_per_image,
}
else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
pixel_values_flat = input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
pixel_values_flat = (
torch.cat(pixel_values_lst)
if len(pixel_values_lst) > 1
else pixel_values_lst[0]
)
image_inputs = {
"pixel_values_flat": pixel_values_flat,
......@@ -863,6 +868,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def _videos_to_pixel_values_lst(
self,
videos: list[npt.NDArray],
*,
dtype: torch.dtype = torch.float32,
) -> list[torch.Tensor]:
return [
video_to_pixel_values(
......@@ -872,6 +879,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
patch_size=self.config.patch_size,
downsample_ratio=self.config.downsample_ratio,
norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=dtype,
)
for video in videos
]
......@@ -886,8 +896,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
videos_lst = [v[0] for v in videos]
video_metadata_lst = [v[1] for v in videos]
pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst,
dtype=self.dtype,
)
# We use frame duration in milliseconds (as integer) to ensure
......@@ -903,10 +915,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
metadata["frames_indices"] for metadata in video_metadata_lst
]
video_num_patches = torch.tensor([len(item) for item in pixel_values_lst_video])
# Normalization already fused into resize above.
# Skip the torch.cat copy when there is exactly one video
if len(pixel_values_lst_video) == 1:
pixel_values_flat = pixel_values_lst_video[0]
else:
pixel_values_flat = torch.cat(pixel_values_lst_video)
video_inputs = {
"pixel_values_flat_video": input_conditioner(
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
),
"pixel_values_flat_video": pixel_values_flat,
"video_num_patches": video_num_patches,
"frames_indices": frames_indices_lst,
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
......@@ -1168,20 +1185,21 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
for i, _ in enumerate(tokens_per_frame)
]
# Tokenize frame separator independently
frame_separators_tokenized = [
_seq2tokens(tokenizer, sep) for sep in frame_separators
]
# Batch-tokenize all frame separators at once — the HuggingFace
# tokenizers Rust backend parallelizes batch encoding across threads.
batch_encoded = tokenizer(
frame_separators,
add_special_tokens=False,
return_attention_mask=False,
)
frame_separators_tokenized: list[list[int]] = batch_encoded["input_ids"]
# Tokenize each component independently to avoid tokenizer merging tokens
# across boundaries. This ensures consistent tokenization regardless of
# num_tokens_per_frame values.
all_token_ids = []
for i, num_tokens in enumerate(tokens_per_frame):
frame_sep_token_ids = frame_separators_tokenized[i]
all_token_ids.extend(frame_sep_token_ids)
# Add pre-tokenized special tokens
all_token_ids.extend(frame_separators_tokenized[i])
all_token_ids.extend(img_start_token_ids)
all_token_ids.extend(img_context_token_ids * num_tokens)
all_token_ids.extend(img_end_token_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment