Commit 375a6f77 authored by gushiqiao, committed by GitHub

[Feat] support tae for wan_2_2 (#275)

parent f2e1def0
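Note: this change wires a Tiny AutoEncoder (TAE) decoder into the Wan2.2 pipeline. Based only on the config keys touched in this diff (use_tiny_vae, tiny_vae_path, need_scaled), enabling it would look roughly like the sketch below; the surrounding config structure is an assumption and is not shown in this commit.

# Illustrative config fragment -- only these three keys appear in the diff; the rest of
# the runner config is assumed, and the taew2_2.pth checkpoint must exist locally.
wan22_tiny_vae_config = {
    "use_tiny_vae": True,            # route load_vae_decoder() to the tiny TAE decoder
    "tiny_vae_path": "taew2_2.pth",  # optional; falls back to self.tiny_vae_name
    "need_scaled": True,             # un-normalize latents with the Wan2.2 mean/std tables
}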
......@@ -23,7 +23,7 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
......@@ -34,6 +34,10 @@ from lightx2v.utils.utils import best_output_size, cache_video
class WanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
self.vae_cls = WanVAE
self.tiny_vae_cls = WanVAE_tiny
self.vae_name = "Wan2.1_VAE.pth"
self.tiny_vae_name = "taew2_1.pth"
def load_transformer(self):
model = WanModel(
......@@ -133,7 +137,7 @@ class WanRunner(DefaultRunner):
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
"vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
"device": vae_device,
"parallel": self.config.parallel,
"use_tiling": self.config.get("use_tiling_vae", False),
......@@ -143,7 +147,7 @@ class WanRunner(DefaultRunner):
if self.config.task not in ["i2v", "flf2v", "vace"]:
return None
else:
return WanVAE(**vae_config)
return self.vae_cls(**vae_config)
def load_vae_decoder(self):
# offload config
......@@ -154,7 +158,7 @@ class WanRunner(DefaultRunner):
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
"vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
"device": vae_device,
"parallel": self.config.parallel,
"use_tiling": self.config.get("use_tiling_vae", False),
......@@ -162,10 +166,10 @@ class WanRunner(DefaultRunner):
"dtype": GET_DTYPE(),
}
if self.config.get("use_tiny_vae", False):
tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
vae_decoder = WanVAE_tiny(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", self.tiny_vae_name)
vae_decoder = self.tiny_vae_cls(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
else:
vae_decoder = WanVAE(**vae_config)
vae_decoder = self.vae_cls(**vae_config)
return vae_decoder
def load_vae(self):
......@@ -430,47 +434,10 @@ class Wan22DenseRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
self.vae_encoder_need_img_original = True
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
"dtype": GET_DTYPE(),
}
vae_decoder = Wan2_2_VAE(**vae_config)
return vae_decoder
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
"dtype": GET_DTYPE(),
}
if self.config.task not in ["i2v", "flf2v"]:
return None
else:
return Wan2_2_VAE(**vae_config)
def load_vae(self):
vae_encoder = self.load_vae_encoder()
vae_decoder = self.load_vae_decoder()
return vae_encoder, vae_decoder
self.vae_cls = Wan2_2_VAE
self.tiny_vae_cls = Wan2_2_VAE_tiny
self.vae_name = "Wan2.2_VAE.pth"
self.tiny_vae_name = "taew2_2.pth"
def run_vae_encoder(self, img):
max_area = self.config.target_height * self.config.target_width
......
import os
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
from collections import namedtuple
import torch
......@@ -6,8 +11,6 @@ import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32,expandable_segments:True"
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
......@@ -149,27 +152,31 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), patch_size=1, latent_channels=16):
"""Initialize pretrained TAEHV from the given checkpoint.
Args:
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
patch_size: input/output pixelshuffle patch-size for this model.
latent_channels: number of latent channels (z dim) for this model.
"""
super().__init__()
self.patch_size = patch_size
self.latent_channels = latent_channels
self.image_channels = 3
self.is_cogvideox = checkpoint_path is not None and "taecvx" in checkpoint_path
if checkpoint_path is not None and "taew2_2" in checkpoint_path:
self.patch_size, self.latent_channels = 2, 48
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64),
conv(self.image_channels * self.patch_size**2, 64),
nn.ReLU(inplace=True),
TPool(64, 2),
conv(64, 64, stride=2, bias=False),
......@@ -186,13 +193,13 @@ class TAEHV(nn.Module):
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(),
conv(TAEHV.latent_channels, n_f[0]),
conv(self.latent_channels, n_f[0]),
nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
......@@ -213,7 +220,7 @@ class TAEHV(nn.Module):
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True),
conv(n_f[3], TAEHV.image_channels),
conv(n_f[3], self.image_channels * self.patch_size**2),
)
if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))
......@@ -243,6 +250,13 @@ class TAEHV(nn.Module):
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
if self.patch_size > 1:
x = F.pixel_unshuffle(x, self.patch_size)
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=True):
......@@ -255,16 +269,23 @@ class TAEHV(nn.Module):
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
skip_trim = self.is_cogvideox and x.shape[1] % 2 == 0
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
x = x.clamp_(0, 1)
if self.patch_size > 1:
x = F.pixel_shuffle(x, self.patch_size)
if skip_trim:
# skip trimming for cogvideox to make frame counts match.
# this still doesn't have correct temporal alignment for certain frame counts
# (cogvideox seems to pad at the start?), but for multiple-of-4 it's fine.
return x
return x[:, self.frames_to_trim :]
def forward(self, x):
return self.c(x)
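Note: for a taew2_2 checkpoint the constructor sets patch_size=2 and latent_channels=48, so encode_video first pixel-unshuffles 2x2 pixel patches into channels and pads the frame count to a multiple of 4, while decode_video pixel-shuffles the decoder output back to RGB. The snippet below only traces the tensor shapes of those added steps; the resolution and frame count are arbitrary example values and the encoder/decoder themselves are not run.

import torch
import torch.nn.functional as F

patch_size = 2                                      # value selected for taew2_2 checkpoints
x = torch.randn(1, 9, 3, 480, 832)                  # NTCHW input video, T=9 frames (example)

x = F.pixel_unshuffle(x, patch_size)                # -> (1, 9, 12, 240, 416): 3 * 2**2 channels
if x.shape[1] % 4 != 0:                             # pad T at the end to a multiple of 4
    n_pad = 4 - x.shape[1] % 4
    x = torch.cat([x, x[:, -1:].repeat_interleave(n_pad, dim=1)], 1)
print(x.shape)                                      # torch.Size([1, 12, 12, 240, 416])

y = torch.rand(1, 4, 3 * patch_size**2, 240, 416)   # pretend decoder output in [0, 1]
y = F.pixel_shuffle(y, patch_size)                  # -> (1, 4, 3, 480, 832) RGB frames
print(y.shape)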
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import os
import sys
import cv2 # no highly esteemed deed is commemorated here
......@@ -300,8 +321,10 @@ def main():
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
print("Using device", dev, "and dtype", dtype)
taehv = TAEHV().to(dev, dtype)
checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
print(f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
......@@ -322,7 +345,7 @@ def main():
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + ".reconstructed_by_taehv.mp4"
video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
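Note: main() now reads the checkpoint from the TAEHV_CHECKPOINT_PATH environment variable and names the reconstruction output after that checkpoint. A roundtrip against the Wan2.2 tiny VAE could therefore be driven roughly as below; the script filename "taehv.py" and the sample video are assumptions, not part of the diff.

import os
import subprocess

# "taehv.py" is an assumed filename for this module; adjust to its real path in the repo.
env = dict(os.environ, TAEHV_CHECKPOINT_PATH="taew2_2.pth")
subprocess.run(["python", "taehv.py", "sample.mp4"], check=True, env=env)
# -> writes sample.mp4.reconstructed_by_taew2_2.mp4 next to the input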
......
......@@ -18,10 +18,8 @@ class WanVAE_tiny(nn.Module):
self.device = torch.device("cuda")
self.taehv = TAEHV(vae_pth).to(self.dtype)
self.temperal_downsample = [True, True, False]
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
self.need_scaled = need_scaled
# temp
if self.need_scaled:
self.latents_mean = [
-0.7571,
......@@ -75,3 +73,129 @@ class WanVAE_tiny(nn.Module):
# parallel=False keeps memory low; set parallel=True for faster decoding at higher memory use
return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
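Note: the need_scaled branch in both tiny decoders stores a per-channel mean and std and then computes latents / (1.0 / std) + mean, which is simply latents * std + mean, i.e. it maps latents back from the normalized space before the TAE decoder runs. A small equivalent helper (not part of the diff, purely illustrative) is shown below.

import torch

def unscale_latents(z, mean, std):
    # z: (1, z_dim, T, H, W) latents normalized to roughly zero mean / unit std per channel
    mean = torch.tensor(mean).view(1, -1, 1, 1, 1).to(z.device, z.dtype)
    std = torch.tensor(std).view(1, -1, 1, 1, 1).to(z.device, z.dtype)
    return z * std + mean

# Matches the diff's formulation (divide by the reciprocal std, then add the mean).
z = torch.randn(1, 16, 4, 8, 8)
mean, std = [0.1] * 16, [0.5] * 16
inv_std = 1.0 / torch.tensor(std).view(1, 16, 1, 1, 1)
ref = z / inv_std + torch.tensor(mean).view(1, 16, 1, 1, 1)
assert torch.allclose(unscale_latents(z, mean, std), ref)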
class Wan2_2_VAE_tiny(nn.Module):
def __init__(self, vae_pth="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
super().__init__()
self.dtype = dtype
self.device = torch.device("cuda")
self.taehv = TAEHV(vae_pth).to(self.dtype)
self.need_scaled = need_scaled
if self.need_scaled:
self.latents_mean = [
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
]
self.latents_std = [
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
]
self.z_dim = 48
@peak_memory_decorator
@torch.no_grad()
def decode(self, latents):
latents = latents.unsqueeze(0)
if self.need_scaled:
latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
# parallel=False keeps memory low; set parallel=True for faster decoding at higher memory use
return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
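Note: end to end, the runner constructs the new class with the taew2_2.pth checkpoint and calls decode on (z_dim, T, H, W) latents. A minimal standalone decode sketch, assuming a CUDA device, the checkpoint on disk, and illustrative latent sizes (the import path is the one added at the top of the runner diff):

import torch
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny

decoder = Wan2_2_VAE_tiny(vae_pth="taew2_2.pth", need_scaled=True).to("cuda")
latents = torch.randn(48, 3, 30, 52, dtype=torch.bfloat16, device="cuda")  # (z_dim, T, h, w), sizes illustrative
video = decoder.decode(latents)  # -> (1, 3, T_out, H_out, W_out) RGB, roughly in [-1, 1]
print(video.shape, video.dtype)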