Commit 1bba5529 authored by gushiqiao, committed by GitHub

Merge branch 'main' into dev_gsq

parents 4209fd89 6943aa52
{
"infer_steps": 6,
"target_fps": 16,
"video_duration": 16,
"audio_sr": 16000,
"text_len": 512,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [1.0, 1.0],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"use_31_block": false,
"lora_configs": [
{
"name": "high_noise_model",
"path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
"strength": 1.0
},
{
"name": "low_noise_model",
"path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
"strength": 1.0
}
]
}
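For reference, here is a minimal sketch (not part of this commit) of reading a config like the one above and iterating its lora_configs entries; the helper name and example path are illustrative assumptions only.

# Illustrative only: read an inference config JSON and list its LoRA entries.
import json


def load_infer_config(path):
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Each lora_configs entry names a target sub-model and carries a path and strength.
    for lora in config.get("lora_configs", []):
        print(lora["name"], lora["path"], lora.get("strength", 1.0))
    return config


# e.g. cfg = load_infer_config("configs/wan22/wan_i2v_audio.json")  # path is an assumption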
@@ -13,7 +13,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
 from lightx2v.models.runners.wan.wan_runner import WanRunner, Wan22MoeRunner
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
-from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner, Wan22MoeAudioRunner
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
...
@@ -61,8 +61,18 @@ class WanAudioModel(WanModel):
         if self.scheduler.cnt >= self.scheduler.num_steps:
             self.scheduler.cnt = 0
-        self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
+
+
+class Wan22MoeAudioModel(WanAudioModel):
+    def _load_ckpt(self, use_bf16, skip_bf16):
+        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
+        weight_dict = {}
+        for file_path in safetensors_files:
+            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            weight_dict.update(file_weights)
+        return weight_dict
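The new Wan22MoeAudioModel._load_ckpt merges every *.safetensors shard under the model directory into a single state dict. A standalone sketch of the same pattern using the public safetensors API follows; the function name and the bf16 cast are assumptions, not code from this repository.

# Sketch of merging sharded *.safetensors files into one state dict,
# mirroring the _load_ckpt logic above; names here are illustrative.
import glob
import os

import torch
from safetensors.torch import load_file


def load_sharded_safetensors(model_path, use_bf16=True):
    weight_dict = {}
    for file_path in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))):
        shard = load_file(file_path, device="cpu")  # {param_name: tensor}
        if use_bf16:
            shard = {k: (v.to(torch.bfloat16) if v.is_floating_point() else v) for k, v in shard.items()}
        weight_dict.update(shard)
    return weight_dict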
@@ -9,7 +9,7 @@ class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         d = config["dim"] // config["num_heads"]
+        self.config = config
         self.task = config["task"]
         self.freqs = torch.cat(
             [
@@ -22,6 +22,7 @@ class WanAudioPreInfer(WanPreInfer):
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.text_len = config["text_len"]
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)

     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
@@ -93,13 +94,20 @@ class WanAudioPreInfer(WanPreInfer):
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)
-        if self.task == "i2v":
+        if self.task == "i2v" and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             context_clip = weights.proj_1.apply(context_clip)
             context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
             context_clip = weights.proj_3.apply(context_clip)
             context_clip = weights.proj_4.apply(context_clip)
+            if self.clean_cuda_cache:
+                del clip_fea
+                torch.cuda.empty_cache()
             context = torch.concat([context_clip, context], dim=0)
+        if self.clean_cuda_cache:
+            if self.config.get("use_image_encoder", True):
+                del context_clip
+            torch.cuda.empty_cache()

         return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
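The clean_cuda_cache flag introduced above trades a little speed for lower peak memory: intermediates such as clip_fea and context_clip are deleted and the CUDA caching allocator is flushed only when the flag is set. A self-contained sketch of that guard pattern (variable names are illustrative, not from the repository):

# Illustrative guard for optional CUDA cache clearing, as in the pre-infer path above.
import torch

clean_cuda_cache = True  # would come from config.get("clean_cuda_cache", False)
x = torch.randn(1024, device="cuda") if torch.cuda.is_available() else torch.randn(1024)
# ... use x ...
if clean_cuda_cache:
    del x  # drop the reference first so the allocator can actually reclaim the block
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the driver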
@@ -11,11 +11,12 @@ from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
-from lightx2v.models.networks.wan.audio_model import WanAudioModel
+from lightx2v.models.networks.wan.audio_model import WanAudioModel, Wan22MoeAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from .wan_runner import MultiModelStruct
 from loguru import logger
 from einops import rearrange
@@ -262,7 +263,7 @@ class VideoGenerator:
         if prev_video is None:
             return None
-        device = self.model.device
+        device = torch.device("cuda")
         dtype = torch.bfloat16
         vae_dtype = torch.float
@@ -315,7 +316,7 @@ class VideoGenerator:
         self.model.scheduler.reset()

         # Prepare previous latents - ALWAYS needed, even for first segment
-        device = self.model.device
+        device = torch.device("cuda")
         dtype = torch.bfloat16
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
@@ -423,7 +424,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

         # Audio encoder
-        device = self.model.device
+        device = torch.device("cuda")
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
         self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
@@ -655,7 +656,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
@@ -684,3 +685,44 @@ class WanAudioRunner(WanRunner): # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret
+
+
+@RUNNER_REGISTER("wan2.2_moe_audio")
+class Wan22MoeAudioRunner(WanAudioRunner):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def load_transformer(self):
+        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
+        high_noise_model = Wan22MoeAudioModel(
+            os.path.join(self.config.model_path, "high_noise_model"),
+            self.config,
+            self.init_device,
+        )
+        low_noise_model = Wan22MoeAudioModel(
+            os.path.join(self.config.model_path, "low_noise_model"),
+            self.config,
+            self.init_device,
+        )
+        if self.config.get("lora_configs") and self.config.lora_configs:
+            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
+                if lora_config.name == "high_noise_model":
+                    lora_wrapper = WanLoraWrapper(high_noise_model)
+                    lora_name = lora_wrapper.load_lora(lora_path)
+                    lora_wrapper.apply_lora(lora_name, strength)
+                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
+                if lora_config.name == "low_noise_model":
+                    lora_wrapper = WanLoraWrapper(low_noise_model)
+                    lora_name = lora_wrapper.load_lora(lora_path)
+                    lora_wrapper.apply_lora(lora_name, strength)
+                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
+        # XXX: trick
+        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
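load_transformer hands both experts to MultiModelStruct together with config.boundary (0.900 in the JSON above), and MultiModelStruct logs a derived boundary_timestep. Below is a hedged sketch of how such a boundary could select between the two experts per denoising step; the exact rule lives in MultiModelStruct, and num_train_timesteps=1000 is an assumption.

# Hedged sketch of boundary-based expert selection; illustrative only.
def select_expert_index(timestep, boundary, num_train_timesteps=1000):
    boundary_timestep = boundary * num_train_timesteps
    # Early, noisy steps (large timestep) -> high_noise_model (index 0);
    # once the timestep drops below the boundary -> low_noise_model (index 1).
    return 0 if timestep >= boundary_timestep else 1


# With boundary 0.900: timestep 950 -> 0 (high noise), timestep 300 -> 1 (low noise)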
@@ -307,6 +307,10 @@ class MultiModelStruct:
         self.cur_model_index = -1
         logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

+    @property
+    def device(self):
+        return self.model[self.cur_model_index].device
+
     def set_scheduler(self, shared_scheduler):
         self.scheduler = shared_scheduler
         for model in self.model:
...
@@ -852,13 +852,20 @@ class WanVAE:
             .to(device)
         )

+    def current_device(self):
+        return next(self.model.parameters()).device
+
     def to_cpu(self):
+        self.model.encoder = self.model.encoder.to("cpu")
+        self.model.decoder = self.model.decoder.to("cpu")
         self.model = self.model.to("cpu")
         self.mean = self.mean.cpu()
         self.inv_std = self.inv_std.cpu()
         self.scale = [self.mean, self.inv_std]

     def to_cuda(self):
+        self.model.encoder = self.model.encoder.to("cuda")
+        self.model.decoder = self.model.decoder.to("cuda")
         self.model = self.model.to("cuda")
         self.mean = self.mean.cuda()
         self.inv_std = self.inv_std.cuda()
@@ -872,9 +879,9 @@ class WanVAE:
             self.to_cuda()
         if self.use_tiling:
-            out = [self.model.tiled_encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+            out = [self.model.tiled_encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
         else:
-            out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+            out = [self.model.encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()
...
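The WanVAE changes above make to_cpu/to_cuda move the encoder and decoder explicitly and route inputs through current_device(), so encoding still works when cpu_offload bounces the model between devices. A minimal sketch of that offload pattern for an arbitrary nn.Module follows (the module and flag are placeholders, not the WanVAE API):

# Minimal sketch of the offload pattern used around WanVAE.encode above; illustrative only.
import torch


def encode_with_offload(module, x, cpu_offload):
    if torch.cuda.is_available():
        module = module.to("cuda")
    device = next(module.parameters()).device  # same idea as WanVAE.current_device()
    out = module(x.to(device))
    if cpu_offload:
        module.to("cpu")
    return out.float()


# e.g. encode_with_offload(torch.nn.Linear(4, 4), torch.randn(2, 4), cpu_offload=True)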
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path="/mnt/Text2Video2/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.2-I2V-A14B"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Change it in this script or set the environment variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
#-m debugpy --wait-for-client --listen 0.0.0.0:15684 \
python \
-m lightx2v.infer \
--model_cls wan2.2_moe_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_i2v_audio.json \
--prompt "The video features an old lady saying something and knitting a sweater." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_audio.mp4