"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2bb2bdd5d41af52a19d34cf3ee5d4148839562e5"
Unverified commit 58d2b10a, authored by YiYi Xu and committed by GitHub
Browse files

[wan2.2] fix vae patches (#12041)

up
parent 20e0740b
......@@ -913,38 +913,21 @@ def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
# x shape: [batch_size, channels, height, width]
batch_size, channels, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
elif x.dim() == 5:
# x shape: [batch_size, channels, frames, height, width]
batch_size, channels, frames, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
if x.dim() != 5:
raise ValueError(f"Invalid input shape: {x.shape}")
# x shape: [batch_size, channels, frames, height, width]
batch_size, channels, frames, height, width = x.shape
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
return x
......@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
# x shape: [b, (c * patch_size * patch_size), h, w]
batch_size, c_patches, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, height, width)
# Rearrange to [b, c, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(batch_size, channels, height * patch_size, width * patch_size)
elif x.dim() == 5:
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
batch_size, c_patches, frames, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
if x.dim() != 5:
raise ValueError(f"Invalid input shape: {x.shape}")
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
batch_size, c_patches, frames, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, f, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
# Reshape to [b, c, patch_size, patch_size, f, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
return x
......@@ -1044,7 +1016,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
patch_size: Optional[int] = None,
scale_factor_temporal: Optional[int] = 4,
scale_factor_spatial: Optional[int] = 8,
clip_output: bool = True,
) -> None:
super().__init__()
......@@ -1244,10 +1215,11 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
if self.config.clip_output:
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment