Unverified commit d8827789 authored by gushiqiao, committed by GitHub

Add vae test script (#404)

parent 9bb7dfe9
......@@ -273,7 +273,7 @@ def run_inference(
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
......@@ -491,8 +491,8 @@ def run_inference(
"clip_quant_scheme": clip_quant_scheme,
"vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
"use_tiling_vae": use_tiling_vae,
"use_tiny_vae": use_tiny_vae,
"tiny_vae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tiny_vae else None),
"use_tae": use_tae,
"tae_pth": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
"lazy_load": lazy_load,
"do_mm_calib": False,
"parallel_attn_type": None,
......@@ -590,7 +590,7 @@ def auto_configure(enable_auto_config, resolution):
"t5_quant_scheme_val": "bf16",
"clip_quant_scheme_val": "fp16",
"precision_mode_val": "fp32",
"use_tiny_vae_val": False,
"use_tae_val": False,
"use_tiling_vae_val": False,
"enable_teacache_val": False,
"teacache_thresh_val": 0.26,
......@@ -692,7 +692,7 @@ def auto_configure(enable_auto_config, resolution):
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
},
),
(
......@@ -713,7 +713,7 @@ def auto_configure(enable_auto_config, resolution):
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
},
),
]
......@@ -754,7 +754,7 @@ def auto_configure(enable_auto_config, resolution):
"unload_modules_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
"use_tiny_vae_val": True,
"use_tae_val": True,
}
if res == "540p"
else {
......@@ -770,7 +770,7 @@ def auto_configure(enable_auto_config, resolution):
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
}
),
),
......@@ -813,7 +813,7 @@ def auto_configure(enable_auto_config, resolution):
{
"t5_quant_scheme_val": quant_type,
"unload_modules_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
},
),
]
......@@ -1161,7 +1161,7 @@ def main():
gr.Markdown("### Variational Autoencoder (VAE)")
with gr.Row():
use_tiny_vae = gr.Checkbox(
use_tae = gr.Checkbox(
label="Use Tiny VAE",
value=False,
info="Use a lightweight VAE model to accelerate the decoding process",
......@@ -1213,7 +1213,7 @@ def main():
t5_quant_scheme,
clip_quant_scheme,
precision_mode,
use_tiny_vae,
use_tae,
use_tiling_vae,
enable_teacache,
teacache_thresh,
......@@ -1248,7 +1248,7 @@ def main():
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
......@@ -1289,7 +1289,7 @@ def main():
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
......
......@@ -275,7 +275,7 @@ def run_inference(
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
......@@ -495,8 +495,8 @@ def run_inference(
"clip_quant_scheme": clip_quant_scheme,
"vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
"use_tiling_vae": use_tiling_vae,
"use_tiny_vae": use_tiny_vae,
"tiny_vae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tiny_vae else None),
"use_tae": use_tae,
"tae_pth": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
"lazy_load": lazy_load,
"do_mm_calib": False,
"parallel_attn_type": None,
......@@ -594,7 +594,7 @@ def auto_configure(enable_auto_config, resolution):
"t5_quant_scheme_val": "bf16",
"clip_quant_scheme_val": "fp16",
"precision_mode_val": "fp32",
"use_tiny_vae_val": False,
"use_tae_val": False,
"use_tiling_vae_val": False,
"enable_teacache_val": False,
"teacache_thresh_val": 0.26,
......@@ -696,7 +696,7 @@ def auto_configure(enable_auto_config, resolution):
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
},
),
(
......@@ -717,7 +717,7 @@ def auto_configure(enable_auto_config, resolution):
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
},
),
]
......@@ -758,7 +758,7 @@ def auto_configure(enable_auto_config, resolution):
"unload_modules_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
"use_tiny_vae_val": True,
"use_tae_val": True,
}
if res == "540p"
else {
......@@ -774,7 +774,7 @@ def auto_configure(enable_auto_config, resolution):
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
}
),
),
......@@ -817,7 +817,7 @@ def auto_configure(enable_auto_config, resolution):
{
"t5_quant_scheme_val": quant_type,
"unload_modules_val": True,
"use_tiny_vae_val": True,
"use_tae_val": True,
},
),
]
......@@ -1163,7 +1163,7 @@ def main():
gr.Markdown("### 变分自编码器(VAE)")
with gr.Row():
use_tiny_vae = gr.Checkbox(
use_tae = gr.Checkbox(
label="使用轻量级VAE",
value=False,
info="使用轻量级VAE模型加速解码过程",
......@@ -1215,7 +1215,7 @@ def main():
t5_quant_scheme,
clip_quant_scheme,
precision_mode,
use_tiny_vae,
use_tae,
use_tiling_vae,
enable_teacache,
teacache_thresh,
......@@ -1250,7 +1250,7 @@ def main():
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
......@@ -1291,7 +1291,7 @@ def main():
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
......
......@@ -23,7 +23,7 @@
"clip_quantized_ckpt": "/path/to/clip-fp8.pth",
"clip_quant_scheme": "fp8",
"use_tiling_vae": true,
"use_tiny_vae": true,
"tiny_vae_path": "/path/to/taew2_1.pth",
"use_tae": true,
"tae_pth": "/path/to/taew2_1.pth",
"lazy_load": true
}
......@@ -23,8 +23,8 @@
"clip_quantized_ckpt": "/path/to/clip-fp8.pth",
"clip_quant_scheme": "fp8",
"use_tiling_vae": true,
"use_tiny_vae": true,
"tiny_vae_path": "/path/to/taew2_1.pth",
"use_tae": true,
"tae_pth": "/path/to/taew2_1.pth",
"lazy_load": true,
"rotary_chunk": true,
"clean_cuda_cache": true
......
......@@ -28,8 +28,8 @@ In some cases, the VAE component can be time-consuming. You can use a lightweight
```python
{
"use_tiny_vae": true,
"tiny_vae_path": "/path to taew2_1.pth"
"use_tae": true,
"tae_pth": "/path to taew2_1.pth"
}
```
The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
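For reference, a minimal decoding sketch using the `WanVAE_tiny` wrapper touched in this commit; the latent shape below is purely illustrative, and the constructor arguments mirror those used elsewhere in the repo:
```python
import torch
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny

# Lightweight TAE decoder; vae_pth points at the downloaded taew2_1.pth.
tae = WanVAE_tiny(vae_pth="/path/to/taew2_1.pth", dtype=torch.bfloat16, device="cuda").to("cuda")

# Illustrative latent tensor in NTCHW layout (batch, frames, channels, height, width).
latents = torch.randn(1, 4, 16, 60, 104, dtype=torch.bfloat16, device="cuda")

# decode_video forwards to TAEHV and returns frames roughly in [0, 1].
frames = tae.decode_video(latents)
print(frames.shape)
```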
......
......@@ -160,8 +160,8 @@ use_tiling_vae = True # Enable VAE chunked inference
```python
# VAE optimization configuration
use_tiny_vae = True # Use lightweight VAE
tiny_vae_path = "/path to taew2_1.pth"
use_tae = True # Use lightweight VAE
tae_pth = "/path to taew2_1.pth"
```
You can download taew2_1.pth [here](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
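The same two keys appear in the inference configs built by the Gradio demo in this commit; a sketch of the VAE-related slice of such a config (paths are placeholders):
```python
# VAE-related portion of an inference config (values are placeholders).
config = {
    "vae_path": "/path/to/Wan2.1_VAE.pth",  # full VAE checkpoint, still used for encoding
    "use_tiling_vae": True,                 # chunked VAE inference
    "use_tae": True,                        # decode with the lightweight TAE instead
    "tae_pth": "/path/to/taew2_1.pth",      # resolved to None when use_tae is False
    "lazy_load": True,
}
```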
......
......@@ -28,8 +28,8 @@
```python
{
"use_tiny_vae": true,
"tiny_vae_path": "/path to taew2_1.pth"
"use_tae": true,
"tae_pth": "/path to taew2_1.pth"
}
```
The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
......
......@@ -160,8 +160,8 @@ use_tiling_vae = True # Enable VAE chunked inference
```python
# VAE optimization configuration
use_tiny_vae = True
tiny_vae_path = "/path to taew2_1.pth"
use_tae = True
tae_pth = "/path to taew2_1.pth"
```
The taew2_1.pth weights can be downloaded from [here](https://github.com/madebyollin/taehv/raw/refs/heads/main/taew2_1.pth)
......
......@@ -178,16 +178,16 @@ class WanRunner(DefaultRunner):
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
}
if self.config.get("use_tiny_vae", False):
tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", self.tiny_vae_name)
vae_decoder = self.tiny_vae_cls(vae_pth=tiny_vae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
if self.config.get("use_tae", False):
tae_pth = find_torch_model_path(self.config, "tae_pth", self.tiny_vae_name)
vae_decoder = self.tiny_vae_cls(vae_pth=tae_pth, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
else:
vae_decoder = self.vae_cls(**vae_config)
return vae_decoder
def load_vae(self):
vae_encoder = self.load_vae_encoder()
if vae_encoder is None or self.config.get("use_tiny_vae", False):
if vae_encoder is None or self.config.get("use_tae", False):
vae_decoder = self.load_vae_decoder()
else:
vae_decoder = vae_encoder
......
......@@ -283,77 +283,3 @@ class TAEHV(nn.Module):
# (cogvideox seems to pad at the start?), but for multiple-of-4 it's fine.
return x
return x[:, self.frames_to_trim :]
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import os
import sys
import cv2 # no highly esteemed deed is commemorated here
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, "writer"):
self.writer.release()
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
print(f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None]
vid_dev = video.to(dev, dtype).div_(255.0)
# convert to device tensor
if video.numel() < 100_000_000:
print(f" {video_path} seems small enough, will process all frames in parallel")
# convert to device tensor
vid_enc = taehv.encode_video(vid_dev)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc)
print(f" Decoded {video_path} -> {vid_dec.shape}")
else:
print(f" {video_path} seems large, will process each frame sequentially")
# convert to device tensor
vid_enc = taehv.encode_video(vid_dev, parallel=False)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
print(f" Saved to {video_out_path}")
if __name__ == "__main__":
main()
import argparse
import cv2
import torch
from loguru import logger
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, "writer"):
self.writer.release()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Encode and decode videos using the TaeHV model for reconstruction")
parser.add_argument("video_paths", nargs="+", help="Paths to input video files (multiple allowed)")
parser.add_argument("--checkpoint", "-c", help=f"Path to the model checkpoint file")
parser.add_argument("--device", "-d", default="cuda", help=f'Computing device (e.g., "cuda", "mps", "cpu"; default: auto-detect available device)')
parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float32"], help="Data type for model computation (default: bfloat16)")
parser.add_argument("--model_type", choices=["taew2_1", "taew2_2", "vaew2_1", "vaew2_2"], required=True, help="Type of the model to use (choices: taew2_1, taew2_2)")
parser.add_argument("--use_lightvae", default=False, action="store_true")
args = parser.parse_args()
if args.use_lightvae:
assert args.model_type in ["vaew2_1"]
if args.device:
dev = torch.device(args.device)
else:
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
model_map = {"taew2_1": WanVAE_tiny, "taew2_2": Wan2_2_VAE_tiny, "vaew2_1": WanVAE, "vaew2_2": Wan2_2_VAE}
dtype = dtype_map[args.dtype]
model_args = {"vae_pth": args.checkpoint, "dtype": dtype, "device": dev}
if args.model_type == "vaew2_1":
model_args.update({"use_lightvae": args.use_lightvae})
model = model_map[args.model_type](**model_args)
if args.model_type.startswith("tae"):
model = model.to(dev)
# Process each input video
for idx, video_path in enumerate(args.video_paths):
logger.info(f"Processing video {video_path}...")
# Read video
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None] # Add batch dimension
vid_dev = video.to(dev, dtype).div_(255.0) # Normalize to [0,1]
# Encode
vid_enc = model.encode_video(vid_dev)
if isinstance(vid_enc, tuple):
vid_enc = vid_enc[0]
# Decode
vid_dec = model.decode_video(vid_enc)
# Save reconstructed video
video_out_path = f"{video_path}.reconstructed_{idx}.mp4"
frame_size = (vid_dec.shape[-1], vid_dec.shape[-2])
fps = int(round(video_in.fps))
video_out = VideoTensorWriter(video_out_path, frame_size, fps)
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
logger.info(f" Reconstructed video saved to {video_out_path}")
......@@ -659,7 +659,7 @@ class WanVAE_(nn.Module):
return dec
def encode(self, x, scale):
def encode(self, x, scale, return_mu=False):
self.clear_cache()
## cache
t = x.shape[2]
......@@ -686,7 +686,10 @@ class WanVAE_(nn.Module):
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
if return_mu:
return mu, log_var
else:
return mu
def decode(self, z, scale):
self.clear_cache()
......@@ -722,12 +725,12 @@ class WanVAE_(nn.Module):
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
def sample(self, imgs, deterministic=False, scale=[0, 1]):
mu, log_var = self.encode(imgs, scale, return_mu=True)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
return mu + std * torch.randn_like(std), mu, log_var
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
......@@ -738,6 +741,24 @@ class WanVAE_(nn.Module):
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def encode_video(self, x, scale=[0, 1]):
assert x.ndim == 5 # NTCHW
assert x.shape[2] % 3 == 0
x = x.transpose(1, 2)
y = x.mul(2).sub_(1)
y, mu, log_var = self.sample(y, scale=scale)
return y.transpose(1, 2).to(x), mu, log_var
def decode_video(self, x, scale=[0, 1]):
assert x.ndim == 5 # NTCHW
assert x.shape[2] % self.z_dim == 0
x = x.transpose(1, 2)
# B, C, T, H, W
y = x
y = self.decode(y, scale).clamp_(-1, 1)
y = y.mul_(0.5).add_(0.5).clamp_(0, 1) # NCTHW
return y.transpose(1, 2).to(x)
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, pruning_rate=0.0, **kwargs):
"""
......@@ -1281,34 +1302,8 @@ class WanVAE:
return images
def encode_video(self, vid):
return self.model.encode_video(vid)
if __name__ == "__main__":
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
# # Test both 1D and 2D splitting
# print(f"Rank {dist.get_rank()}: Testing 1D splitting")
# model_1d = WanVAE(vae_pth="/data/nvme0/models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", dtype=torch.bfloat16, parallel=True, use_2d_split=False)
# model_1d.to_cuda()
input_tensor = torch.randn(1, 3, 17, 480, 480).to(torch.bfloat16).to("cuda")
# encoded_tensor_1d = model_1d.encode(input_tensor)
# print(f"rank {dist.get_rank()} 1D encoded_tensor shape: {encoded_tensor_1d.shape}")
# decoded_tensor_1d = model_1d.decode(encoded_tensor_1d)
# print(f"rank {dist.get_rank()} 1D decoded_tensor shape: {decoded_tensor_1d.shape}")
print(f"Rank {dist.get_rank()}: Testing 2D splitting")
model_2d = WanVAE(vae_pth="/data/nvme0/models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", dtype=torch.bfloat16, parallel=True, use_2d_split=True)
model_2d.to_cuda()
encoded_tensor_2d = model_2d.encode(input_tensor)
print(f"rank {dist.get_rank()} 2D encoded_tensor shape: {encoded_tensor_2d.shape}")
decoded_tensor_2d = model_2d.decode(encoded_tensor_2d)
print(f"rank {dist.get_rank()} 2D decoded_tensor shape: {decoded_tensor_2d.shape}")
# # Verify that both methods produce the same results
# if dist.get_rank() == 0:
# print(f"Encoded tensors match: {torch.allclose(encoded_tensor_1d, encoded_tensor_2d, atol=1e-5)}")
# print(f"Decoded tensors match: {torch.allclose(decoded_tensor_1d, decoded_tensor_2d, atol=1e-5)}")
dist.destroy_process_group()
def decode_video(self, vid_enc):
return self.model.decode_video(vid_enc)
......@@ -743,7 +743,7 @@ class WanVAE_(nn.Module):
x_recon = self.decode(mu, scale)
return x_recon, mu
def encode(self, x, scale):
def encode(self, x, scale, return_mu=False):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
......@@ -769,7 +769,10 @@ class WanVAE_(nn.Module):
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
if return_mu:
return mu, log_var
else:
return mu
def decode(self, z, scale, offload_cache=False):
self.clear_cache()
......@@ -795,12 +798,12 @@ class WanVAE_(nn.Module):
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
def sample(self, imgs, deterministic=False, scale=[0, 1]):
mu, log_var = self.encode(imgs, scale, return_mu=True)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
return mu + std * torch.randn_like(std), mu, log_var
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
......@@ -811,6 +814,24 @@ class WanVAE_(nn.Module):
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def encode_video(self, x, scale=[0, 1]):
assert x.ndim == 5 # NTCHW
assert x.shape[2] % 3 == 0
x = x.transpose(1, 2)
y = x.mul(2).sub_(1)
y, mu, log_var = self.sample(y, scale=scale)
return y.transpose(1, 2).to(x), mu, log_var
def decode_video(self, x, scale=[0, 1]):
assert x.ndim == 5 # NTCHW
assert x.shape[2] % self.z_dim == 0
x = x.transpose(1, 2)
# B, C, T, H, W
y = x
y = self.decode(y, scale).clamp_(-1, 1)
y = y.mul_(0.5).add_(0.5).clamp_(0, 1) # NCTHW
return y.transpose(1, 2).to(x)
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, load_from_rank0=False, **kwargs):
# params
......@@ -1013,3 +1034,9 @@ class Wan2_2_VAE:
images = images.cpu().float()
self.to_cpu()
return images
def encode_video(self, vid):
return self.model.encode_video(vid)
def decode_video(self, vid_enc):
return self.model.decode_video(vid_enc)
import torch
import torch.nn as nn
from lightx2v.models.video_encoders.hf.tae import TAEHV
from lightx2v.utils.memory_profiler import peak_memory_decorator
from ..tae import TAEHV
class DotDict(dict):
__getattr__ = dict.__getitem__
......@@ -74,6 +73,14 @@ class WanVAE_tiny(nn.Module):
# low-memory, set parallel=True for faster + higher memory
return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
@torch.no_grad()
def encode_video(self, vid):
return self.taehv.encode_video(vid)
@torch.no_grad()
def decode_video(self, vid_enc):
return self.taehv.decode_video(vid_enc)
class Wan2_2_VAE_tiny(nn.Module):
def __init__(self, vae_pth="taew2_2.pth", dtype=torch.bfloat16, device="cuda", need_scaled=False):
......@@ -199,3 +206,11 @@ class Wan2_2_VAE_tiny(nn.Module):
# low-memory, set parallel=True for faster + higher memory
return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
@torch.no_grad()
def encode_video(self, vid):
return self.taehv.encode_video(vid)
@torch.no_grad()
def decode_video(self, vid_enc):
return self.taehv.decode_video(vid_enc)
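One detail worth noting across the wrappers above: the full VAE's `encode_video` returns a `(latents, mu, log_var)` tuple, while the TAE wrappers return a plain latent tensor, which is why the test script unwraps tuples before decoding. A minimal caller-side sketch (assuming `model` is any of these wrappers and `video` is an NTCHW float tensor in [0, 1]):
```python
# `model` may be WanVAE / Wan2_2_VAE (tuple return) or a *_tiny TAE wrapper (tensor return).
enc = model.encode_video(video)
latents = enc[0] if isinstance(enc, tuple) else enc

# decode_video returns frames roughly in [0, 1] for both paths.
frames = model.decode_video(latents)
```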