Unverified commit 270e8a41, authored by milesial and committed by GitHub
Browse files

Nemotron Nano VL: Streamline pixel shuffle (#37580)


Signed-off-by: milesial <milesial@users.noreply.github.com>
parent f44afef6
...@@ -1005,38 +1005,27 @@ class NemotronH_Nano_VL_V2( ...@@ -1005,38 +1005,27 @@ class NemotronH_Nano_VL_V2(
) )
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, h, w, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale r = int(1 / scale_factor)
x = x.view( new_h = h // r
n, new_w = w // r
w, new_c = c * r * r
int(h * scale_factor),
int(c / scale_factor), x = x.view(n, new_h, r, new_w, r, c)
)
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale -->
# N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n,
int(h * scale_factor),
int(w * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v1": if self.ps_version == "v1":
warnings.warn( warnings.warn(
"In ps_version 'v1', the height and width have not " "In ps_version 'v1', the height and width have not "
"been swapped back, which results in a transposed image.", "been swapped back, which results in a transposed image.",
stacklevel=2, stacklevel=2,
) )
x = x.permute(0, 3, 1, 2, 4, 5).reshape(n, new_w, new_h, new_c)
else: else:
x = x.permute(0, 2, 1, 3).contiguous() x = x.permute(0, 1, 3, 2, 4, 5).reshape(n, new_h, new_w, new_c)
return x return x
def pixel_shuffle_dynamic_res( def pixel_shuffle_dynamic_res(
self, x: torch.Tensor, *, imgs_sizes: list[tuple[int, int]] self, x: torch.Tensor, *, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor: ) -> torch.Tensor:
scale_factor = self.downsample_ratio
patch_dim = self.patch_size patch_dim = self.patch_size
seq_lens = calc_seq_lens(imgs_sizes, patch_dim) seq_lens = calc_seq_lens(imgs_sizes, patch_dim)
splits = torch.split(x, seq_lens, dim=-2) splits = torch.split(x, seq_lens, dim=-2)
...@@ -1045,22 +1034,8 @@ class NemotronH_Nano_VL_V2( ...@@ -1045,22 +1034,8 @@ class NemotronH_Nano_VL_V2(
h = imgs_sizes[i][0] // patch_dim h = imgs_sizes[i][0] // patch_dim
w = imgs_sizes[i][1] // patch_dim w = imgs_sizes[i][1] // patch_dim
sv = sv.reshape(sv.shape[0], h, w, -1) sv = sv.reshape(sv.shape[0], h, w, -1)
sv = self.pixel_shuffle(sv, scale_factor=self.downsample_ratio)
n, h, w, c = sv.size() sv = sv.flatten(1, 2)
sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor))
sv = sv.permute(0, 2, 1, 3).contiguous()
sv = sv.view(
n,
int(w * scale_factor),
int(h * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v2":
sv = sv.permute(0, 2, 1, 3).contiguous()
sv = sv.reshape(sv.shape[0], -1, sv.shape[-1])
out.append(sv) out.append(sv)
x = torch.cat(out, dim=-2) x = torch.cat(out, dim=-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment