Commit 9e120289 authored by wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents b5bcbed7 9196a220
import argparse
import json
import torch.distributed as dist
from loguru import logger
......@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import set_config, set_parallel_config
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
from lightx2v.utils.utils import seed_all
......@@ -40,18 +39,31 @@ def main():
type=str,
required=True,
choices=[
"wan2.1",
"hunyuan",
"wan2.1_distill",
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"cogvideox",
"wan2.1_audio",
"wan2.2_moe",
"wan2.2",
"wan2.2_moe_audio",
"wan2.2_audio",
"wan2.2",
"wan2.2_moe_distill",
],
default="wan2.1",
)
......@@ -70,17 +82,16 @@ def main():
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path (directory or file) to save the output video")
args = parser.parse_args()
logger.info(f"args: {args}")
# set config
config = set_config(args)
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
if config.parallel:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
set_parallel_config(config)
print_config(config)
with ProfilingContext("Total Cost"):
runner = init_runner(config)
runner.run_pipeline()
......
......@@ -540,7 +540,6 @@ class T5EncoderModel:
t5_quantized=False,
t5_quantized_ckpt=None,
quant_scheme=None,
seq_p_group=None,
):
self.text_len = text_len
self.dtype = dtype
......
......@@ -418,13 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True):
self.dtype = dtype
self.device = device
self.quantized = clip_quantized
self.cpu_offload = cpu_offload
self.use_31_block = use_31_block
self.seq_p_group = seq_p_group
if self.quantized:
self.checkpoint_path = clip_quantized_ckpt
......
......@@ -16,8 +16,8 @@ class WanAudioModel(WanModel):
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device, seq_p_group=None):
super().__init__(model_path, config, device, seq_p_group)
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _init_infer_class(self):
super()._init_infer_class()
......
......@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device, seq_p_group=None):
super().__init__(model_path, config, device, seq_p_group)
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
......
......@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device, seq_p_group=None):
super().__init__(model_path, config, device, seq_p_group)
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _load_ckpt(self, unified_dtype, sensitive_layer):
# For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
......
......@@ -74,10 +74,6 @@ class WanPreInfer:
x = x.flatten(2).transpose(1, 2).contiguous()
seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
# wan2.2_moe expands t here. We found that doing or skipping this makes little difference, and the expansion adds overhead; for now we stay faithful to the original code, and this could be removed later.
if self.config["model_cls"] == "wan2.2_moe":
t = t.expand(seq_lens[0])
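# Note: t is a single timestep here, so the expansion only repeats the same value and
# sinusoidal_embedding_1d then produces identical rows per position, which likely explains
# why skipping it changes little while adding compute.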
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
if self.enable_dynamic_cfg:
s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
......
......@@ -41,12 +41,16 @@ class WanModel:
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device, seq_p_group=None):
def __init__(self, model_path, config, device):
self.model_path = model_path
self.config = config
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.seq_p_group = seq_p_group
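# seq_p_group is now derived from the config's device_mesh when sequence parallelism is enabled,
# instead of being passed in as a constructor argument: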
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
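# Illustrative note (assumption, not shown in this diff): the "device_mesh" entry consumed above is
# presumably built during config setup with torch.distributed.device_mesh, roughly:
#     from torch.distributed.device_mesh import init_device_mesh
#     mesh = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
#     config["device_mesh"] = mesh
# after which mesh.get_group(mesh_dim="seq_p") yields the process group used for sequence parallelism.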
......@@ -252,8 +256,6 @@ class WanModel:
if target_device == "cuda":
dist.barrier(device_ids=[torch.cuda.current_device()])
else:
dist.barrier()
for key in sorted(synced_meta_dict.keys()):
if is_weight_loader:
......@@ -390,11 +392,11 @@ class WanModel:
x = F.pad(x, (0, 0, 0, padding_size)) # (pad spec runs from the last dim backward: later dims first, then earlier dims)
x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"].startswith("wan2.2"):
padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
if padding_size > 0:
embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) # (pads later dims first, then earlier dims)
embed = F.pad(embed, (0, 0, 0, padding_size))
# if self.config["model_cls"] == "wan2.2":
# padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
# if padding_size > 0:
# embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) # (pads later dims first, then earlier dims)
# embed = F.pad(embed, (0, 0, 0, padding_size))
pre_infer_out.x = x
pre_infer_out.embed = embed
......
......@@ -343,7 +343,7 @@ class VideoGenerator:
self.model.scheduler.reset()
inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)
# Run inference loop
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
......
......@@ -29,7 +29,6 @@ class WanCausVidRunner(WanRunner):
self.config.model_path,
self.config,
self.init_device,
self.seq_p_group,
)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
......
......@@ -21,7 +21,6 @@ class WanDistillRunner(WanRunner):
self.config.model_path,
self.config,
self.init_device,
self.seq_p_group,
)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
......@@ -91,7 +90,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
os.path.join(self.config.model_path, "high_noise_model"),
self.config,
self.init_device,
self.seq_p_group,
)
high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config.lora_configs:
......@@ -106,7 +104,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
self.config,
self.init_device,
self.seq_p_group,
)
if use_low_lora:
......
......@@ -34,18 +34,12 @@ from lightx2v.utils.utils import best_output_size, cache_video
class WanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
device_mesh = self.config.get("device_mesh")
if device_mesh is not None:
self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
def load_transformer(self):
model = WanModel(
self.config.model_path,
self.config,
self.init_device,
self.seq_p_group,
)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
......@@ -83,7 +77,6 @@ class WanRunner(DefaultRunner):
clip_quantized=clip_quantized,
clip_quantized_ckpt=clip_quantized_ckpt,
quant_scheme=clip_quant_scheme,
seq_p_group=self.seq_p_group,
cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
use_31_block=self.config.get("use_31_block", True),
)
......@@ -127,7 +120,6 @@ class WanRunner(DefaultRunner):
t5_quantized=t5_quantized,
t5_quantized_ckpt=t5_quantized_ckpt,
quant_scheme=t5_quant_scheme,
seq_p_group=self.seq_p_group,
)
text_encoders = [text_encoder]
return text_encoders
......@@ -145,7 +137,6 @@ class WanRunner(DefaultRunner):
"device": vae_device,
"parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
"use_tiling": self.config.get("use_tiling_vae", False),
"seq_p_group": self.seq_p_group,
"cpu_offload": vae_offload,
}
if self.config.task != "i2v":
......
......@@ -60,12 +60,9 @@ class WanScheduler(BaseScheduler):
device=self.device,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2":
if self.config["task"] == "t2v":
self.mask = masks_like(self.latents, zero=False)
elif self.config["task"] == "i2v":
self.mask = masks_like(self.latents, zero=True)
self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
self.mask = masks_like(self.latents, zero=True)
self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
def set_timesteps(
self,
......
......@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
......@@ -795,7 +795,6 @@ class WanVAE:
device="cuda",
parallel=False,
use_tiling=False,
seq_p_group=None,
cpu_offload=False,
):
self.dtype = dtype
......@@ -845,7 +844,7 @@ class WanVAE:
self.scale = [self.mean, self.inv_std]
# init model
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
def current_device(self):
return next(self.model.parameters()).device
......
......@@ -69,9 +69,6 @@ def set_config(args):
def set_parallel_config(config):
if config.parallel:
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
cfg_p_size = config.parallel.get("cfg_p_size", 1)
seq_p_size = config.parallel.get("seq_p_size", 1)
assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must equal world_size ({dist.get_world_size()})"
......@@ -82,3 +79,13 @@ def set_parallel_config(config):
if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
config["cfg_parallel"] = True
def print_config(config):
    config_to_print = config.copy()
    config_to_print.pop("device_mesh", None)
    if config.parallel:
        if dist.get_rank() == 0:
            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
    else:
        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")