Unverified Commit 270e8a41 authored by milesial's avatar milesial Committed by GitHub
Browse files

Nemotron Nano VL: Streamline pixel shuffle (#37580)


Signed-off-by: default avatarmilesial <milesial@users.noreply.github.com>
parent f44afef6
......@@ -1005,38 +1005,27 @@ class NemotronH_Nano_VL_V2(
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(
n,
w,
int(h * scale_factor),
int(c / scale_factor),
)
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale -->
# N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n,
int(h * scale_factor),
int(w * scale_factor),
int(c / (scale_factor * scale_factor)),
)
n, h, w, c = x.size()
r = int(1 / scale_factor)
new_h = h // r
new_w = w // r
new_c = c * r * r
x = x.view(n, new_h, r, new_w, r, c)
if self.ps_version == "v1":
warnings.warn(
"In ps_version 'v1', the height and width have not "
"been swapped back, which results in a transposed image.",
stacklevel=2,
)
x = x.permute(0, 3, 1, 2, 4, 5).reshape(n, new_w, new_h, new_c)
else:
x = x.permute(0, 2, 1, 3).contiguous()
x = x.permute(0, 1, 3, 2, 4, 5).reshape(n, new_h, new_w, new_c)
return x
def pixel_shuffle_dynamic_res(
self, x: torch.Tensor, *, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
scale_factor = self.downsample_ratio
patch_dim = self.patch_size
seq_lens = calc_seq_lens(imgs_sizes, patch_dim)
splits = torch.split(x, seq_lens, dim=-2)
......@@ -1045,22 +1034,8 @@ class NemotronH_Nano_VL_V2(
h = imgs_sizes[i][0] // patch_dim
w = imgs_sizes[i][1] // patch_dim
sv = sv.reshape(sv.shape[0], h, w, -1)
n, h, w, c = sv.size()
sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor))
sv = sv.permute(0, 2, 1, 3).contiguous()
sv = sv.view(
n,
int(w * scale_factor),
int(h * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v2":
sv = sv.permute(0, 2, 1, 3).contiguous()
sv = sv.reshape(sv.shape[0], -1, sv.shape[-1])
sv = self.pixel_shuffle(sv, scale_factor=self.downsample_ratio)
sv = sv.flatten(1, 2)
out.append(sv)
x = torch.cat(out, dim=-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment