Commit d061ae81 authored by gushiqiao, committed by GitHub

[Fea] add approximate patch vae (#230)

parent fba9754a
@@ -17,8 +17,8 @@
     "use_31_block": false,
     "adaptive_resize": true,
     "parallel": {
-        "vae_p_size": 4,
         "seq_p_size": 4,
-        "seq_p_attn_type": "ulysses"
+        "seq_p_attn_type": "ulysses",
+        "use_patch_vae": false
     }
 }
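The hunk above removes the old `vae_p_size` knob and adds a `use_patch_vae` switch inside the `parallel` block; later in this commit, `WanVAE` reads that flag to enable approximate patch-parallel decoding, while `seq_p_size` keeps controlling sequence parallelism for the transformer. A minimal sketch of opting in (the dict mirrors the JSON block, values are illustrative):

```python
# Sketch only: the same parallel block with the new flag switched on.
parallel_cfg = {
    "seq_p_size": 4,
    "seq_p_attn_type": "ulysses",
    "use_patch_vae": True,   # opt in to approximate patch-parallel VAE decode
}

# Mirrors the check added to WanVAE.__init__ further down in this diff.
enable_patch_vae = bool(parallel_cfg) and parallel_cfg.get("use_patch_vae", False)
```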
@@ -273,10 +273,10 @@ class VideoGenerator:
         _, nframe, height, width = self.model.scheduler.latents.shape
         if self.config.model_cls == "wan2.2_audio":
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
             _, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
         else:
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype))[0].to(dtype)
         if prev_video is not None:
             prev_token_length = (prev_frame_length - 1) // 4 + 1
@@ -370,6 +370,7 @@ class VideoGenerator:
         # Decode latents
         latents = self.model.scheduler.latents
         generator = self.model.scheduler.generator
-        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        with ProfilingContext("Run VAE Decoder"):
+            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
         gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
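`ProfilingContext` comes from lightx2v's profiling utilities and its implementation is not part of this diff. As a rough, assumed illustration of the pattern only (not the actual class), a timing context manager used this way could look like:

```python
import time
from contextlib import contextmanager

from loguru import logger


@contextmanager
def profiling_context(name: str):
    """Assumed stand-in for ProfilingContext: logs the wall-clock time of a block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"{name} took {time.perf_counter() - start:.3f}s")


# Usage mirroring the hunk above:
# with profiling_context("Run VAE Decoder"):
#     gen_video = vae_decoder.decode(latents, generator=generator, config=config)
```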
@@ -667,7 +668,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
+        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float))
         if self.config.model_cls == "wan2.2_audio":
             vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
...
@@ -135,7 +135,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
             "device": vae_device,
-            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
+            "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
         }
@@ -155,7 +155,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
             "device": vae_device,
-            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
+            "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
         }
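In both `vae_config` blocks above, the `parallel` entry changes from a boolean derived in the runner (the old `vae_p_size > 1` test) to the full parallel config object, so the VAE wrapper can inspect sub-options such as `use_patch_vae` on its own. A small sketch of the difference (a plain dict stands in for the real config object):

```python
parallel = {"seq_p_size": 4, "seq_p_attn_type": "ulysses", "use_patch_vae": False}

# Old behaviour: the runner reduced the block to a bool (vae_p_size no longer exists).
old_parallel_flag = bool(parallel) and parallel.get("vae_p_size", 0) > 1

# New behaviour: the whole dict is forwarded. Truthiness still works as an on/off
# check inside WanVAE, and sub-options stay readable via parallel.get(...).
vae_config = {"parallel": parallel}
```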
@@ -313,7 +313,7 @@ class WanRunner(DefaultRunner):
             dim=1,
         ).cuda()
-        vae_encoder_out = self.vae_encoder.encode([vae_input], self.config)[0]
+        vae_encoder_out = self.vae_encoder.encode([vae_input])[0]
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
@@ -497,5 +497,5 @@ class Wan22DenseRunner(WanRunner):
         return vae_encoder_out

     def get_vae_encoder_output(self, img):
-        z = self.vae_encoder.encode(img, self.config)
+        z = self.vae_encoder.encode(img)
         return z
@@ -37,7 +37,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         config.lat_h = lat_h
         config.lat_w = lat_w
-        vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
+        vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()])[0]
         vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
         return vae_encoder_out
@@ -87,7 +87,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         for i in range(n_iter):
             if output_video is not None:  # i !=0
                 prefix_video = output_video[:, :, -overlap_history:].to(self.model.scheduler.device)
-                prefix_video = self.vae_model.encode(prefix_video, self.config)[0]  # [(b, c, f, h, w)]
+                prefix_video = self.vae_model.encode(prefix_video)[0]  # [(b, c, f, h, w)]
                 if prefix_video.shape[1] % causal_block_size != 0:
                     truncate_len = prefix_video.shape[1] % causal_block_size
                     # the length of prefix video is truncated for the casual block size alignment.
...
@@ -7,6 +7,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger

+from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv
+from lightx2v.models.video_encoders.hf.wan.dist.split_gather import gather_forward_split_backward, split_forward_gather_backward
 from lightx2v.utils.utils import load_weights

 __all__ = [
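`split_forward_gather_backward` and `gather_forward_split_backward` come from the new `dist.split_gather` module, whose body is not shown in this diff. Based on how they are called below (split the latent along dim 3 before decoding, gather the decoded frames afterwards), a minimal inference-only sketch of such primitives might be the following; this is assumed behavior, not the actual lightx2v code, and the forward/backward naming suggests the real versions are autograd-aware:

```python
import torch
import torch.distributed as dist


def split_forward_gather_backward(group, tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Keep only this rank's chunk along `dim` (even split assumed for the sketch)."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    return torch.chunk(tensor, world_size, dim=dim)[rank].contiguous()


def gather_forward_split_backward(group, tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """All-gather every rank's chunk and concatenate them back along `dim`."""
    world_size = dist.get_world_size(group)
    chunks = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(chunks, tensor.contiguous(), group=group)
    return torch.cat(chunks, dim=dim)
```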
@@ -517,6 +519,7 @@ class WanVAE_(nn.Module):
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+        self.use_approximate_patch = False

         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
@@ -547,6 +550,12 @@ class WanVAE_(nn.Module):
             dropout,
         )

+    def enable_approximate_patch(self):
+        self.use_approximate_patch = True
+
+    def disable_approximate_patch(self):
+        self.use_approximate_patch = False
+
     def forward(self, x):
         mu, log_var = self.encode(x)
         z = self.reparameterize(mu, log_var)
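The two toggles above are what `WanVAE.__init__` calls later in this commit when `use_patch_vae` is set; patch mode is a pure runtime switch with no weight changes. An illustrative caller (the helper and its arguments are hypothetical):

```python
def decode_with_patch_parallel(vae_, z, scale):
    """Hypothetical helper: toggle approximate patch mode around one decode call."""
    vae_.enable_approximate_patch()       # each rank decodes only its latent slice
    try:
        return vae_.decode(z, scale)      # split/gather happens inside decode()
    finally:
        vae_.disable_approximate_patch()  # restore full-frame decoding
```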
@@ -629,6 +638,9 @@ class WanVAE_(nn.Module):
         return enc

     def tiled_decode(self, z, scale):
+        if self.use_approximate_patch:
+            z = split_forward_gather_backward(None, z, 3)
+
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
         else:
@@ -678,6 +690,8 @@ class WanVAE_(nn.Module):
             result_rows.append(torch.cat(result_row, dim=-1))

         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+        if self.use_approximate_patch:
+            dec = gather_forward_split_backward(None, dec, 3)
         return dec
@@ -686,7 +700,6 @@ class WanVAE_(nn.Module):
         ## cache
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
-        ## split the encode input x along the time axis into chunks of 1, 4, 4, 4, ...
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
@@ -707,11 +720,15 @@ class WanVAE_(nn.Module):
                 mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
             else:
                 mu = (mu - scale[0]) * scale[1]
         self.clear_cache()
         return mu

     def decode(self, z, scale):
         self.clear_cache()
+        if self.use_approximate_patch:
+            z = split_forward_gather_backward(None, z, 3)
+
         # z: [b,c,t,h,w]
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
@@ -734,6 +751,10 @@ class WanVAE_(nn.Module):
                     feat_idx=self._conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
+
+        if self.use_approximate_patch:
+            out = gather_forward_split_backward(None, out, 3)
+
         self.clear_cache()
         return out
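With approximate patch mode on, `decode` first reduces the latent to this rank's slice along dim 3 (the latent height), decodes it locally, then gathers the decoded slices back along the same dim. Each slice is decoded without its neighbors' context, which is presumably why the commit calls the result "approximate". A small sanity check of the shape bookkeeping, assuming the usual 8x spatial / 4x temporal Wan VAE compression and an evenly divisible latent height (example numbers only):

```python
# z: [b, c, t, h, w] latent; 4 ranks split along dim 3 (latent height).
b, c, t, h, w = 1, 16, 21, 60, 104   # example latent for an 81-frame 480x832 clip
world_size = 4

h_local = h // world_size            # split_forward_gather_backward(None, z, 3): 60 -> 15 rows per rank
frames = (t - 1) * 4 + 1             # temporal upsampling: 21 latent frames -> 81 video frames
dec_local = (b, 3, frames, h_local * 8, w * 8)   # each rank decodes its slice: (1, 3, 81, 120, 832)
dec_full = (b, 3, frames, h * 8, w * 8)          # gather along dim 3 restores (1, 3, 81, 480, 832)
print(dec_local, dec_full)
```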
@@ -845,6 +866,12 @@ class WanVAE:
         # init model
         self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
+        self.use_approximate_patch = False
+        if self.parallel and self.parallel.get("use_patch_vae", False):
+            # assert not self.use_tiling
+            DistributedEnv.initialize(None)
+            self.use_approximate_patch = True
+            self.model.enable_approximate_patch()

     def current_device(self):
         return next(self.model.parameters()).device
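`DistributedEnv.initialize(None)` comes from the new `dist.distributed_env` module, also not shown in this diff. Presumably it makes sure a process group exists before the split/gather collectives run; a guarded sketch of that idea (assumed stand-in, not the actual class):

```python
import torch.distributed as dist


class _DistributedEnvSketch:
    """Assumed stand-in: idempotent process-group setup for the patch-VAE collectives."""

    @staticmethod
    def initialize(group=None):
        if dist.is_available() and not dist.is_initialized():
            # Typical torchrun-style env:// initialization; the backend choice is an assumption.
            dist.init_process_group(backend="nccl", init_method="env://")
        return group
```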
@@ -865,11 +892,11 @@ class WanVAE:
             self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]

-    def encode(self, videos, args):
+    def encode(self, videos):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+        if self.cpu_offload:
             self.to_cuda()

         if self.use_tiling:
@@ -877,7 +904,7 @@ class WanVAE:
         else:
             out = [self.model.encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]

-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+        if self.cpu_offload:
             self.to_cpu()
         return out
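After this change `encode` no longer takes the runner config; `cpu_offload` is read from the state captured at construction time, which is why the call sites earlier in this commit drop their second argument. An illustrative caller with the new signature (helper and shapes are hypothetical):

```python
import torch


def encode_clip(wan_vae, video: torch.Tensor):
    """Hypothetical caller: `wan_vae` is a constructed WanVAE instance."""
    # Old call sites passed the runner config as a second argument; now the
    # instance's own cpu_offload flag decides whether weights move to CUDA.
    return wan_vae.encode([video])[0]   # video: [C, T, H, W] tensor
```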
@@ -902,7 +929,8 @@ class WanVAE:
         elif split_dim == 3:
             zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()

-        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
+        images = decode_func(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)

         if cur_rank == 0:
             if split_dim == 2:
@@ -933,23 +961,21 @@ class WanVAE:
         if self.cpu_offload:
             self.to_cuda()

-        if self.parallel:
+        if self.parallel and not self.use_approximate_patch:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
             height, width = zs.shape[2], zs.shape[3]
             if width % world_size == 0:
-                split_dim = 3
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim)
+                images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
             elif height % world_size == 0:
-                split_dim = 2
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim)
+                images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
             else:
                 logger.info("Fall back to naive decode mode")
                 images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        elif self.use_tiling:
-            images = self.model.tiled_decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         else:
-            images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+            decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
+            images = decode_func(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)

         if self.cpu_offload:
             images = images.cpu().float()
...