Unverified Commit b7332b05 authored by nvnbagrov's avatar nvnbagrov Committed by GitHub
Browse files

[Model] Nano Nemotron VL - fast media preprocessing (#35657)


Signed-off-by: default avatarNatan Bagrov <nbagrov@nvidia.com>
parent 40077ea3
......@@ -17,11 +17,11 @@ from functools import cached_property
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import einops
import numpy as np
import numpy.typing as npt
import regex as re
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
......@@ -214,7 +214,12 @@ NanoNemotronVLVideoInputs: TypeAlias = (
def dynamic_preprocess(
image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0
image,
*,
image_size=512,
max_num_tiles=12,
use_thumbnail=True,
idx=0,
):
orig_width, orig_height = image.size
......@@ -227,35 +232,44 @@ def dynamic_preprocess(
image_size=image_size,
use_thumbnail=False,
)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
processed_images = [
img.convert("RGB") if img.mode != "RGB" else img for img in processed_images
]
processed_images = [
T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(
img
image = np.asarray(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
)
for img in processed_images
]
processed_images = [T.ToTensor()(img) for img in processed_images]
return processed_images
image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3)
image = image.permute(0, 3, 1, 2) # (1, 3, H, W)
resized_img = torch.nn.functional.interpolate(
image,
size=(target_height, target_width),
mode="bicubic",
align_corners=False,
antialias=True,
)
B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size
patches = (
resized_img.reshape(B, C, hp, image_size, wp, image_size)
.permute(0, 2, 4, 1, 3, 5)
.reshape(B * hp * wp, C, image_size, image_size)
/ 255.0
)
if use_thumbnail and patches.shape[0] > 1:
thumb = (
torch.nn.functional.interpolate(
image,
size=(image_size, image_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
patches = torch.cat([patches, thumb], dim=0)
return list(patches)
def image_to_pixel_values(
......@@ -287,22 +301,21 @@ def video_to_pixel_values(
) -> torch.Tensor:
assert max_num_tiles == 1, "Video modality always uses one tile"
# Convert each frame to a single resized tile tensor consistent
# with image path
frames_tensors: list[torch.Tensor] = []
for frame in video:
pil_frame = dynamic_preprocess(
Image.fromarray(frame, mode="RGB"),
image_size=input_size,
max_num_tiles=max_num_tiles,
use_thumbnail=use_thumbnail,
idx=0,
# (num_frames, H, W, C) -> (num_frames, C, H, W)
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(input_size, input_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
# dynamic_preprocess returns tensors already; take the single tile
assert len(pil_frame) >= 1
frames_tensors.append(pil_frame[-1])
return torch.stack(frames_tensors)
video_tensor = video_tensor / 255.0
return video_tensor
def input_conditioner(x, norm_mean, norm_std):
......@@ -346,12 +359,6 @@ class DynamicResolutionImageTiler:
self._factor_max = factor_max
self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
self._transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.ToTensor(),
]
)
assert downsample_ratio < 1
reduction_factor = 1 / downsample_ratio
assert reduction_factor == 2.0
......@@ -441,15 +448,25 @@ class DynamicResolutionImageTiler:
patch_size: tuple[int, int]
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
resized_img = params.media.resize(
(
params.patch_size[0] * self._patch_size,
target_size = (
params.patch_size[1] * self._patch_size,
params.patch_size[0] * self._patch_size,
)
image = np.asarray(
params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
dtype=np.uint8,
)
processed_images = [resized_img]
return [self._transform(img) for img in processed_images]
resized_img = (
torch.nn.functional.interpolate(
torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2),
size=target_size,
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
return list(resized_img)
def process_media(
self,
......@@ -803,6 +820,7 @@ class BaseNanoNemotronVLProcessor(ABC):
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)]
return text, image_inputs
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
......@@ -922,14 +940,14 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
frames_indices_lst = [
metadata["frames_indices"] for metadata in video_metadata_lst
]
video_num_patches = torch.tensor(
[len(item) for item in pixel_values_lst_video]
)
video_inputs = {
"pixel_values_flat_video": input_conditioner(
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
),
"video_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst_video]
),
"video_num_patches": video_num_patches,
"frames_indices": frames_indices_lst,
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
}
......@@ -985,6 +1003,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_repl.full, skip_special_tokens=False
)
text = [t.replace("<video>", video_repl_text, 1) for t in text]
return text, video_inputs
def _preprocess_audio(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment