Unverified Commit a0d487b2 authored by Netanel Haber's avatar Netanel Haber Committed by GitHub
Browse files

nano_nemotron_vl: suppress readonly torch.from_numpy() warning in image and...


nano_nemotron_vl: suppress readonly torch.from_numpy() warning in image and video resize paths (#37903)
Signed-off-by: default avatarNetanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent b73b5b06
......@@ -8,6 +8,7 @@
# --------------------------------------------------------
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
......@@ -66,6 +67,30 @@ def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.
return (x - norm_mean) / norm_std
def _bicubic_from_ndarray(
array: npt.NDArray[Any], *, size: tuple[int, int]
) -> torch.Tensor:
"""
Convert a 4D NHWC ndarray to NCHW and interpolate with bicubic.
Suppresses PyTorch's non-writable NumPy warning because interpolate copies,
and torch.from_numpy(array) is discarded at the end of function scope.
"""
with warnings.catch_warnings():
msg = "The given NumPy array is not writ.*"
# Apparently, different versions of PyTorch use writable or writeable.
warnings.filterwarnings("ignore", message=msg, category=UserWarning)
tensor = torch.from_numpy(array)
assert tensor.ndim == 4, f"{tensor.ndim=}"
tensor = tensor.permute(0, 3, 1, 2)
return (
torch.nn.functional.interpolate(
tensor, size=size, mode="bicubic", align_corners=False, antialias=True
)
/ 255.0
)
def dynamic_preprocess(
image,
*,
......@@ -90,36 +115,19 @@ def dynamic_preprocess(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
)
image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3)
image = image.permute(0, 3, 1, 2) # (1, 3, H, W)
image = np.expand_dims(image, axis=0)
resized_img = torch.nn.functional.interpolate(
image,
size=(target_height, target_width),
mode="bicubic",
align_corners=False,
antialias=True,
)
resized_img = _bicubic_from_ndarray(image, size=(target_height, target_width))
B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size
patches = (
resized_img.reshape(B, C, hp, image_size, wp, image_size)
.permute(0, 2, 4, 1, 3, 5)
.reshape(B * hp * wp, C, image_size, image_size)
/ 255.0
)
if use_thumbnail and patches.shape[0] > 1:
thumb = (
torch.nn.functional.interpolate(
image,
size=(image_size, image_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
thumb = _bicubic_from_ndarray(image, size=(image_size, image_size))
patches = torch.cat([patches, thumb], dim=0)
return list(patches)
......@@ -241,21 +249,9 @@ def video_to_pixel_values(
downsample_ratio=downsample_ratio,
)
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(target_h, target_w),
mode="bicubic",
align_corners=False,
antialias=True,
)
return _bicubic_from_ndarray(video, size=(target_h, target_w))
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(input_size, input_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
return _bicubic_from_ndarray(video, size=(input_size, input_size))
video_tensor = video_tensor / 255.0
......@@ -385,16 +381,8 @@ class DynamicResolutionImageTiler:
params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
dtype=np.uint8,
)
resized_img = (
torch.nn.functional.interpolate(
torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2),
size=target_size,
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
image = np.expand_dims(image, axis=0)
resized_img = _bicubic_from_ndarray(image, size=target_size)
return list(resized_img)
def process_media(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment