Commit 9e120289 authored by wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents b5bcbed7 9196a220
 import argparse
-import json
 import torch.distributed as dist
 from loguru import logger
@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.set_config import set_config, set_parallel_config
+from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
@@ -40,18 +39,31 @@ def main():
         type=str,
         required=True,
         choices=[
             "wan2.1",
             "hunyuan",
             "wan2.1_distill",
             "wan2.1_causvid",
             "wan2.1_skyreels_v2_df",
             "cogvideox",
             "wan2.1_audio",
             "wan2.2_moe",
"wan2.2",
"wan2.2_moe_audio", "wan2.2_moe_audio",
"wan2.2_audio", "wan2.2_audio",
"wan2.2",
"wan2.2_moe_distill", "wan2.2_moe_distill",
,
], ],
default="wan2.1", default="wan2.1",
) )
@@ -70,17 +82,16 @@ def main():
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
-    logger.info(f"args: {args}")
     # set config
     config = set_config(args)
-    logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
     if config.parallel:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
+    print_config(config)
     with ProfilingContext("Total Cost"):
         runner = init_runner(config)
         runner.run_pipeline()
......
@@ -540,7 +540,6 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
-        seq_p_group=None,
     ):
         self.text_len = text_len
         self.dtype = dtype
......
@@ -418,13 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
         self.use_31_block = use_31_block
-        self.seq_p_group = seq_p_group
         if self.quantized:
             self.checkpoint_path = clip_quantized_ckpt
......
@@ -16,8 +16,8 @@ class WanAudioModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)
     def _init_infer_class(self):
         super()._init_infer_class()
......
@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
......
@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
......
@@ -74,10 +74,6 @@ class WanPreInfer:
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
-        # wan2.2_moe expands t here; we found it makes little difference whether this is done, and the expansion adds latency. For now we stay faithful to the original code; consider removing this later.
-        if self.config["model_cls"] == "wan2.2_moe":
-            t = t.expand(seq_lens[0])
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
             s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
......
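Note on the WanPreInfer hunk above: the removed branch expanded the scalar timestep `t` to the sequence length for `wan2.2_moe` before building the timestep embedding; per the dropped comment it had little effect and only added latency. For reference, a hedged sketch of a standard 1-D sinusoidal embedding in the shape of the `sinusoidal_embedding_1d(self.freq_dim, t.flatten())` call; the exact frequency scaling in lightx2v may differ.

```python
import torch

def sinusoidal_embedding_1d_sketch(dim: int, position: torch.Tensor) -> torch.Tensor:
    # Standard sin/cos timestep embedding: half the channels get cosines, half sines.
    # Illustrative stand-in, not the lightx2v implementation.
    assert dim % 2 == 0
    half = dim // 2
    freqs = torch.pow(10000.0, -torch.arange(half, dtype=torch.float64) / half)
    angles = torch.outer(position.to(torch.float64), freqs)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)

# With the removed expansion, t of shape (1,) would have been broadcast to (seq_len,),
# producing one identical embedding row per token instead of a single row.
```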
@@ -41,12 +41,16 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
-    def __init__(self, model_path, config, device, seq_p_group=None):
+    def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.seq_p_group = seq_p_group
+        if self.config["seq_parallel"]:
+            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
@@ -252,8 +256,6 @@ class WanModel:
         if target_device == "cuda":
             dist.barrier(device_ids=[torch.cuda.current_device()])
-        else:
-            dist.barrier()
         for key in sorted(synced_meta_dict.keys()):
             if is_weight_loader:
@@ -390,11 +392,11 @@ class WanModel:
             x = F.pad(x, (0, 0, 0, padding_size))  # (pad for last dim, pad for earlier dim)
         x = torch.chunk(x, world_size, dim=0)[cur_rank]
-        if self.config["model_cls"].startswith("wan2.2"):
-            padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
-            if padding_size > 0:
-                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (pad for last dim, pad for earlier dim)
-                embed = F.pad(embed, (0, 0, 0, padding_size))
+        # if self.config["model_cls"] == "wan2.2":
+        #     padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
+        #     if padding_size > 0:
+        #         embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (pad for last dim, pad for earlier dim)
+        #         embed = F.pad(embed, (0, 0, 0, padding_size))
         pre_infer_out.x = x
         pre_infer_out.embed = embed
......
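The hunk above keeps the sequence-parallel split of `x` (pad the token dimension up to a multiple of `world_size`, then take this rank's chunk) while the analogous padding of `embed`/`embed0` for wan2.2 is now commented out. A minimal numeric sketch of the kept step, with made-up shapes and a `padding_size` computation that is not shown in this hunk:

```python
import torch
import torch.nn.functional as F

world_size, cur_rank = 4, 1
x = torch.randn(10, 8)  # 10 tokens, hidden size 8 (illustrative shapes)

# Pad the token dimension up to the next multiple of world_size (here 10 -> 12).
padding_size = (world_size - x.shape[0] % world_size) % world_size
if padding_size > 0:
    x = F.pad(x, (0, 0, 0, padding_size))  # last dim untouched, token dim padded at the end

# Every rank now gets an equally sized slice of the sequence.
x_local = torch.chunk(x, world_size, dim=0)[cur_rank]
print(x_local.shape)  # torch.Size([3, 8])
```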
@@ -343,7 +343,7 @@ class VideoGenerator:
             self.model.scheduler.reset()
             inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)
         # Run inference loop
         if total_steps is None:
             total_steps = self.model.scheduler.infer_steps
......
@@ -29,7 +29,6 @@ class WanCausVidRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
......
@@ -21,7 +21,6 @@ class WanDistillRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
@@ -91,7 +90,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         high_lora_wrapper = WanLoraWrapper(high_noise_model)
         for lora_config in self.config.lora_configs:
@@ -106,7 +104,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         if use_low_lora:
......
@@ -34,18 +34,12 @@ from lightx2v.utils.utils import best_output_size, cache_video
 class WanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
-        device_mesh = self.config.get("device_mesh")
-        if device_mesh is not None:
-            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
-        else:
-            self.seq_p_group = None
     def load_transformer(self):
         model = WanModel(
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
@@ -83,7 +77,6 @@ class WanRunner(DefaultRunner):
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
-            seq_p_group=self.seq_p_group,
             cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
             use_31_block=self.config.get("use_31_block", True),
         )
@@ -127,7 +120,6 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
-            seq_p_group=self.seq_p_group,
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -145,7 +137,6 @@ class WanRunner(DefaultRunner):
             "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
-            "seq_p_group": self.seq_p_group,
             "cpu_offload": vae_offload,
         }
         if self.config.task != "i2v":
......
@@ -60,12 +60,9 @@ class WanScheduler(BaseScheduler):
             device=self.device,
             generator=self.generator,
         )
-        if self.config["model_cls"] == "wan2.2":
-            if self.config["task"] == "t2v":
-                self.mask = masks_like(self.latents, zero=False)
-            elif self.config["task"] == "i2v":
-                self.mask = masks_like(self.latents, zero=True)
-                self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+            self.mask = masks_like(self.latents, zero=True)
+            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
     def set_timesteps(
         self,
......
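On the WanScheduler hunk above: after the merge, the mask and latent blend are only set up for wan2.2 i2v; the old t2v branch, which built a mask via `masks_like(..., zero=False)` and never blended, is gone. A hedged sketch of what the kept branch computes, with a stand-in for `masks_like`; the assumption that `zero=True` zeroes the conditioned leading frame along dim 1 is ours, not taken from this diff.

```python
import torch

def masks_like_stub(latents: torch.Tensor, zero: bool) -> torch.Tensor:
    # Assumption: dim 1 is the temporal axis and zero=True zeroes the first frame,
    # so the VAE conditioning replaces the latents there and noise is kept elsewhere.
    mask = torch.ones_like(latents)
    if zero:
        mask[:, 0] = 0.0
    return mask

latents = torch.randn(16, 21, 8, 8)          # illustrative (C, T, H, W) latent shape
vae_encoder_out = torch.randn_like(latents)  # conditioning latents from the VAE encoder
mask = masks_like_stub(latents, zero=True)
# Where mask == 0 the conditioning wins; where mask == 1 the noisy latents survive.
latents = (1.0 - mask) * vae_encoder_out + mask * latents
```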
@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num
-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -795,7 +795,6 @@ class WanVAE:
         device="cuda",
         parallel=False,
         use_tiling=False,
-        seq_p_group=None,
         cpu_offload=False,
     ):
         self.dtype = dtype
@@ -845,7 +844,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]
         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
     def current_device(self):
         return next(self.model.parameters()).device
......
@@ -69,9 +69,6 @@ def set_config(args):
 def set_parallel_config(config):
     if config.parallel:
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
         cfg_p_size = config.parallel.get("cfg_p_size", 1)
         seq_p_size = config.parallel.get("seq_p_size", 1)
         assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
@@ -82,3 +79,13 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
         config["cfg_parallel"] = True
+
+
+def print_config(config):
+    config_to_print = config.copy()
+    config_to_print.pop("device_mesh", None)
+    if config.parallel:
+        if dist.get_rank() == 0:
+            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
+    else:
+        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")