Unverified Commit a0d487b2 authored by Netanel Haber's avatar Netanel Haber Committed by GitHub
Browse files

nano_nemotron_vl: suppress readonly torch.from_numpy() warning in image and...


nano_nemotron_vl: suppress readonly torch.from_numpy() warning in image and video resize paths (#37903)
Signed-off-by: default avatarNetanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent b73b5b06
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# -------------------------------------------------------- # --------------------------------------------------------
import math import math
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
...@@ -66,6 +67,30 @@ def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch. ...@@ -66,6 +67,30 @@ def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.
return (x - norm_mean) / norm_std return (x - norm_mean) / norm_std
def _bicubic_from_ndarray(
array: npt.NDArray[Any], *, size: tuple[int, int]
) -> torch.Tensor:
"""
Convert a 4D NHWC ndarray to NCHW and interpolate with bicubic.
Suppresses PyTorch's non-writable NumPy warning because interpolate copies,
and torch.from_numpy(array) is discarded at the end of function scope.
"""
with warnings.catch_warnings():
msg = "The given NumPy array is not writ.*"
# Apparently, different versions of PyTorch use writable or writeable.
warnings.filterwarnings("ignore", message=msg, category=UserWarning)
tensor = torch.from_numpy(array)
assert tensor.ndim == 4, f"{tensor.ndim=}"
tensor = tensor.permute(0, 3, 1, 2)
return (
torch.nn.functional.interpolate(
tensor, size=size, mode="bicubic", align_corners=False, antialias=True
)
/ 255.0
)
def dynamic_preprocess( def dynamic_preprocess(
image, image,
*, *,
...@@ -90,36 +115,19 @@ def dynamic_preprocess( ...@@ -90,36 +115,19 @@ def dynamic_preprocess(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8 image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
) )
image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3) image = np.expand_dims(image, axis=0)
image = image.permute(0, 3, 1, 2) # (1, 3, H, W)
resized_img = torch.nn.functional.interpolate( resized_img = _bicubic_from_ndarray(image, size=(target_height, target_width))
image,
size=(target_height, target_width),
mode="bicubic",
align_corners=False,
antialias=True,
)
B, C, H, W = resized_img.shape B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size hp, wp = H // image_size, W // image_size
patches = ( patches = (
resized_img.reshape(B, C, hp, image_size, wp, image_size) resized_img.reshape(B, C, hp, image_size, wp, image_size)
.permute(0, 2, 4, 1, 3, 5) .permute(0, 2, 4, 1, 3, 5)
.reshape(B * hp * wp, C, image_size, image_size) .reshape(B * hp * wp, C, image_size, image_size)
/ 255.0
) )
if use_thumbnail and patches.shape[0] > 1: if use_thumbnail and patches.shape[0] > 1:
thumb = ( thumb = _bicubic_from_ndarray(image, size=(image_size, image_size))
torch.nn.functional.interpolate(
image,
size=(image_size, image_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
patches = torch.cat([patches, thumb], dim=0) patches = torch.cat([patches, thumb], dim=0)
return list(patches) return list(patches)
...@@ -241,21 +249,9 @@ def video_to_pixel_values( ...@@ -241,21 +249,9 @@ def video_to_pixel_values(
downsample_ratio=downsample_ratio, downsample_ratio=downsample_ratio,
) )
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w: if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
video_tensor = torch.nn.functional.interpolate( return _bicubic_from_ndarray(video, size=(target_h, target_w))
video_tensor,
size=(target_h, target_w),
mode="bicubic",
align_corners=False,
antialias=True,
)
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size: elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
video_tensor = torch.nn.functional.interpolate( return _bicubic_from_ndarray(video, size=(input_size, input_size))
video_tensor,
size=(input_size, input_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
video_tensor = video_tensor / 255.0 video_tensor = video_tensor / 255.0
...@@ -385,16 +381,8 @@ class DynamicResolutionImageTiler: ...@@ -385,16 +381,8 @@ class DynamicResolutionImageTiler:
params.media.convert("RGB") if params.media.mode != "RGB" else params.media, params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
dtype=np.uint8, dtype=np.uint8,
) )
resized_img = ( image = np.expand_dims(image, axis=0)
torch.nn.functional.interpolate( resized_img = _bicubic_from_ndarray(image, size=target_size)
torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2),
size=target_size,
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
return list(resized_img) return list(resized_img)
def process_media( def process_media(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment