Commit 8d32295d authored by wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents daa06243 1994ffb1
{
"infer_steps": 50,
"target_video_length": 121,
"text_len": 512,
"target_height": 704,
"target_width": 1280,
"num_channels_latents": 48,
"vae_stride": [4, 16, 16],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5.0,
"sample_shift": 5.0,
"enable_cfg": true,
"fps": 24,
"use_image_encoder": false,
"cpu_offload": true,
"offload_granularity": "model",
"vae_offload_cache": true
}
{
"infer_steps": 50,
"target_video_length": 121,
"text_len": 512,
"target_height": 704,
"target_width": 1280,
"num_channels_latents": 48,
"vae_stride": [4, 16, 16],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5.0,
"sample_shift": 5.0,
"enable_cfg": true,
"fps": 24,
"cpu_offload": true,
"offload_granularity": "model",
"vae_offload_cache": true
}
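The two configs above are identical except that the first additionally sets use_image_encoder to false. The offload keys they introduce cascade: a per-module key such as vae_cpu_offload or t5_cpu_offload falls back to the global cpu_offload flag when unset. A minimal sketch of that lookup pattern, using a plain dict in place of the real config object (which supports the same .get interface):

    # Illustration only: the fallback chain used by the runner changes below.
    config = {
        "cpu_offload": True,             # global switch, as in the JSON above
        "offload_granularity": "model",  # 'model' or 'block'
        "vae_offload_cache": True,       # park decoder feat_cache on CPU
    }

    # Per-module keys win; missing ones inherit the global flag.
    vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload"))
    t5_offload = config.get("t5_cpu_offload", config.get("cpu_offload"))
    print(vae_offload, t5_offload)  # True True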
......@@ -68,13 +68,13 @@ class WanRunner(DefaultRunner):
assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
clip_original_ckpt = None
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")
clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)
image_encoder = CLIPModel(
dtype=torch.float16,
......@@ -90,7 +90,7 @@ class WanRunner(DefaultRunner):
def load_text_encoder(self):
# offload config
t5_offload = self.config.get("t5_cpu_offload", False)
t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
if t5_offload:
t5_device = torch.device("cpu")
else:
......@@ -103,14 +103,14 @@ class WanRunner(DefaultRunner):
assert t5_quant_scheme is not None
tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
t5_original_ckpt = None
tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
else:
t5_quant_scheme = None
t5_quantized_ckpt = None
t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")
text_encoder = T5EncoderModel(
......@@ -121,7 +121,7 @@ class WanRunner(DefaultRunner):
tokenizer_path=tokenizer_path,
shard_fn=None,
cpu_offload=t5_offload,
offload_granularity=self.config.get("t5_offload_granularity", "model"),
offload_granularity=self.config.get("t5_offload_granularity", "model"), # support ['model', 'block']
t5_quantized=t5_quantized,
t5_quantized_ckpt=t5_quantized_ckpt,
quant_scheme=t5_quant_scheme,
......@@ -131,12 +131,20 @@ class WanRunner(DefaultRunner):
return text_encoders
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
"device": self.init_device,
"device": vae_device,
"parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
"use_tiling": self.config.get("use_tiling_vae", False),
"seq_p_group": self.seq_p_group,
"cpu_offload": vae_offload,
}
if self.config.task != "i2v":
return None
......@@ -144,11 +152,19 @@ class WanRunner(DefaultRunner):
return WanVAE(**vae_config)
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
"device": self.init_device,
"device": vae_device,
"parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
}
if self.config.get("use_tiny_vae", False):
tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
......@@ -398,17 +414,33 @@ class Wan22DenseRunner(WanRunner):
super().__init__(config)
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": self.init_device,
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
vae_decoder = Wan2_2_VAE(**vae_config)
return vae_decoder
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": self.init_device,
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
if self.config.task != "i2v":
return None
......
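The load_vae_* and load_text_encoder changes above all follow one lifecycle: when offload is enabled the module is instantiated on CPU and only migrated to the GPU around its forward call, then paged back out. A hedged sketch of that lifecycle with a stand-in module (OffloadedLinear is hypothetical; the real classes expose to_cuda/to_cpu helpers, shown in the WanVAE hunks below):

    import torch

    def select_device(offload: bool) -> torch.device:
        # Mirrors the runner: build weights on CPU when offloading, else on GPU.
        return torch.device("cpu") if offload else torch.device("cuda")

    class OffloadedLinear:
        # Hypothetical stand-in for a VAE/T5 module managed by the runner.
        def __init__(self, offload: bool):
            self.cpu_offload = offload
            self.net = torch.nn.Linear(8, 8).to(select_device(offload))

        def __call__(self, x: torch.Tensor) -> torch.Tensor:
            if self.cpu_offload and torch.cuda.is_available():
                self.net.cuda()  # page weights in just for this call
            y = self.net(x.to(next(self.net.parameters()).device))
            if self.cpu_offload:
                y = y.cpu()
                self.net.cpu()   # free GPU memory again
            return y

    print(OffloadedLinear(offload=True)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])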
......@@ -797,11 +797,13 @@ class WanVAE:
parallel=False,
use_tiling=False,
seq_p_group=None,
cpu_offload=False,
):
self.dtype = dtype
self.device = device
self.parallel = parallel
self.use_tiling = use_tiling
self.cpu_offload = cpu_offload
mean = [
-0.7571,
......@@ -938,8 +940,8 @@ class WanVAE:
return images
def decode(self, zs, generator, config):
if config.cpu_offload:
def decode(self, zs, **args):
if self.cpu_offload:
self.to_cuda()
if self.parallel:
......@@ -960,7 +962,7 @@ class WanVAE:
else:
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if config.cpu_offload:
if self.cpu_offload:
images = images.cpu().float()
self.to_cpu()
......
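Note the decode signature change: the positional generator and config arguments are replaced by **args, and the offload decision now comes from self.cpu_offload captured in __init__. Legacy call sites that still pass the old keywords keep working because **args swallows them. A hypothetical illustration of just that mechanism:

    # Before: vae.decode(zs, generator, config)   -- offload flag read from config
    # After:  vae.decode(zs)                      -- flag lives on the instance
    class DecodeSketch:
        def __init__(self, cpu_offload=False):
            self.cpu_offload = cpu_offload

        def decode(self, zs, **args):
            # args may still carry legacy keywords such as generator/config.
            return f"offload={self.cpu_offload}, ignored={sorted(args)}"

    print(DecodeSketch(cpu_offload=True).decode([0.0], generator=None, config=None))
    # offload=True, ignored=['config', 'generator']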
......@@ -619,7 +619,7 @@ class Decoder3d(nn.Module):
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, offload_cache=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
......@@ -639,14 +639,24 @@ class Decoder3d(nn.Module):
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
idx = feat_idx[0]
x = layer(x, feat_cache, feat_idx)
if offload_cache:
for _idx in range(idx, feat_idx[0]):
if isinstance(feat_cache[_idx], torch.Tensor):
feat_cache[_idx] = feat_cache[_idx].cpu()
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
idx = feat_idx[0]
x = layer(x, feat_cache, feat_idx, first_chunk)
if offload_cache:
for _idx in range(idx, feat_idx[0]):
if isinstance(feat_cache[_idx], torch.Tensor):
feat_cache[_idx] = feat_cache[_idx].cpu()
else:
x = layer(x)
......@@ -664,7 +674,7 @@ class Decoder3d(nn.Module):
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_cache[idx] = cache_x.cpu() if offload_cache else cache_x
feat_idx[0] += 1
else:
x = layer(x)
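The offload_cache branches above trade PCIe transfers for GPU memory: once a block has produced its cache entries, those tensors are written back to CPU so only the chunk currently being decoded keeps activations resident on the GPU. A minimal sketch of the round-trip, assuming the consuming layer moves entries back to the compute device on demand:

    import torch

    CACHE_T = 2
    offload_cache = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feat_cache = [None]  # one slot per cached causal conv, as in Decoder3d

    def layer(x, cached):
        # Consumer side: pull the cached activation back to the compute device.
        if isinstance(cached, torch.Tensor):
            x = x + cached.to(x.device)
        return x

    x = torch.randn(1, 3, 4, 4, 4, device=device)
    cache_x = x[:, :, -CACHE_T:, :, :].clone()  # temporal slice kept for the next chunk
    x = layer(x, feat_cache[0])
    # Producer side: park the fresh cache entry on CPU to cap GPU residency.
    feat_cache[0] = cache_x.cpu() if offload_cache else cache_x
    print(feat_cache[0].device)  # cpu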
......@@ -755,7 +765,7 @@ class WanVAE_(nn.Module):
self.clear_cache()
return mu
def decode(self, z, scale):
def decode(self, z, scale, offload_cache=False):
self.clear_cache()
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
......@@ -766,18 +776,9 @@ class WanVAE_(nn.Module):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, offload_cache=offload_cache)
else:
out_ = self.decoder(
x[:, :, i : i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, offload_cache=offload_cache)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
......@@ -830,18 +831,11 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
class Wan2_2_VAE:
def __init__(
self,
z_dim=48,
c_dim=160,
vae_pth=None,
dim_mult=[1, 2, 4, 4],
temperal_downsample=[False, True, True],
dtype=torch.float,
device="cuda",
):
def __init__(self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False):
self.dtype = dtype
self.device = device
self.cpu_offload = cpu_offload
self.offload_cache = offload_cache
self.mean = torch.tensor(
[
......@@ -991,11 +985,11 @@ class Wan2_2_VAE:
self.to_cpu()
return out
def decode(self, zs, generator, config):
if config.cpu_offload:
def decode(self, zs, **args):
if self.cpu_offload:
self.to_cuda()
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if config.cpu_offload:
images = self.model.decode(zs.unsqueeze(0), self.scale, offload_cache=self.offload_cache if self.cpu_offload else False).float().clamp_(-1, 1)
if self.cpu_offload:
images = images.cpu().float()
self.to_cpu()
return images
......@@ -258,7 +258,7 @@ def save_to_video(
raise ValueError(f"Unknown save method: {method}")
def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8"]):
def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
return config.get(ckpt_config_key)
......@@ -277,7 +277,7 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
return config.get(ckpt_config_key)
......
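The widened subdir defaults mean checkpoints placed under the new distill_* folders are discovered without any config changes. The function bodies are collapsed above; a hedged reconstruction of the search order they imply (find_first and the example paths are assumptions, not the actual implementation):

    import os

    def find_first(model_root, filename, subdirs):
        # Assumed order: the model root itself, then each known subdirectory.
        # (An explicit ckpt_config_key, when set, short-circuits this entirely.)
        for d in [""] + list(subdirs):
            candidate = os.path.join(model_root, d, filename)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"PyTorch model file '{filename}' not found.")

    # find_first("/models/wan", "clip-fp8.pth",
    #            ["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"])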
......@@ -5,7 +5,7 @@ lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
......
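The allocator setting added to the script ties into the offload work above: expandable_segments:True lets PyTorch's CUDA caching allocator grow segments in place instead of reserving fixed-size blocks, which curbs fragmentation when allocation sizes vary across offload/onload cycles. The Python equivalent, for contexts where the shell export is impractical:

    # Must be set before the CUDA allocator initializes; before importing torch is safest.
    import os
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    import torch  # the allocator reads the variable when it starts up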