Commit 2a9a64d0 authored by gushiqiao, committed by GitHub

[Fix] Fix audio model offload bug. (#247)

* [Fix] Fix audio model offload bug.

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix
parent abeb9bc8
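The fix applies one consistent CPU-offload pattern to every audio-side module touched below (the audio feature encoder, the adapter's audio projection, its per-block cross-attention, and its time embedding): the module stays on the CPU and is moved onto the GPU only for the duration of its forward pass. A minimal sketch of that pattern, with an illustrative helper name that does not appear in the diff:

```python
import torch
import torch.nn as nn


def run_offloaded(module: nn.Module, *inputs, cpu_offload: bool = True):
    """Run `module` on the GPU for one forward pass, then park it back on the CPU.

    Illustrative only: the diff below inlines this pattern around audio_proj,
    the cross-attention blocks (ca), and time_embedding rather than using a helper.
    """
    if cpu_offload:
        module.to("cuda")
    with torch.no_grad():
        out = module(*inputs)
    if cpu_offload:
        module.to("cpu")
    return out
```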
......@@ -24,5 +24,7 @@
"t5_cpu_offload": true,
"offload_ratio_val": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true
"use_tiling_vae": true,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": false
}
......@@ -17,8 +17,10 @@
"use_31_block": false,
"cpu_offload": true,
"offload_granularity": "block",
"t5_cpu_offload": true,
"offload_ratio_val": 1,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"use_tiling_vae": true
"use_tiling_vae": true,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": false
}
......@@ -18,5 +18,7 @@
"t5_cpu_offload": true,
"offload_ratio_val": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true
"use_tiling_vae": true,
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false
}
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": true,
"use_31_block": false,
"adaptive_resize": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
},
"t5_cpu_offload": true,
"t5_quantized": true,
"t5_quant_scheme": "fp8",
"clip_quantized": true,
"clip_quant_scheme": "fp8",
"use_tiling_vae": true,
"use_tiny_vae": true,
"lazy_load": true,
"rotary_chunk": true,
"clean_cuda_cache": true,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": true
}
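Per the runner change further down, both new flags fall back to the global `cpu_offload` value when they are absent, so configs that predate this commit keep their existing behaviour. A small sketch of that lookup, using a plain dict in place of the config object:

```python
# Assumed example config that only sets the global flag.
config = {"cpu_offload": True}

# Mirrors the lookups in WanAudioRunner.load_audio_encoder / load_audio_adapter.
audio_encoder_offload = config.get("audio_encoder_cpu_offload", config.get("cpu_offload", False))
audio_adapter_offload = config.get("audio_adapter_cpu_offload", config.get("cpu_offload", False))

assert audio_encoder_offload and audio_adapter_offload
```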
......@@ -3,87 +3,16 @@ try:
except ModuleNotFoundError:
flash_attn = None
import math
import os
import safetensors
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8
def load_safetensors(in_path: str):
if os.path.isdir(in_path):
return load_safetensors_from_dir(in_path)
elif os.path.isfile(in_path):
return load_safetensors_from_path(in_path)
else:
raise ValueError(f"{in_path} does not exist")
def load_safetensors_from_path(in_path: str):
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors
def load_safetensors_from_dir(in_dir: str):
tensors = {}
safetensors = os.listdir(in_dir)
safetensors = [f for f in safetensors if f.endswith(".safetensors")]
for f in safetensors:
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
return tensors
def load_pt_safetensors(in_path: str):
ext = os.path.splitext(in_path)[-1]
if ext in (".pt", ".pth", ".tar"):
state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
else:
state_dict = load_safetensors(in_path)
return state_dict
def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
model = model.to("cuda")
# Determine whether the current process is the leader (responsible for loading weights)
is_leader = False
if dist.is_initialized():
current_rank = dist.get_rank()
if current_rank == 0:
is_leader = True
elif not dist.is_initialized() or dist.get_rank() == 0:
is_leader = True
if is_leader:
logger.info(f"Loading model state from {in_path}")
state_dict = load_pt_safetensors(in_path)
model.load_state_dict(state_dict, strict=strict)
# Broadcast the model state from the leader to every other process in the group
if dist.is_initialized():
dist.barrier(device_ids=[torch.cuda.current_device()])
src_global_rank = 0
for param in model.parameters():
dist.broadcast(param.data, src=src_global_rank)
for buffer in model.buffers():
dist.broadcast(buffer.data, src=src_global_rank)
elif dist.is_initialized():
dist.barrier(device_ids=[torch.cuda.current_device()])
for param in model.parameters():
dist.broadcast(param.data, src=0)
for buffer in model.buffers():
dist.broadcast(buffer.data, src=0)
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear")
......@@ -265,8 +194,10 @@ class AudioAdapter(nn.Module):
projection_transformer_layers: int = 4,
quantized: bool = False,
quant_scheme: str = None,
cpu_offload: bool = False,
):
super().__init__()
self.cpu_offload = cpu_offload
self.audio_proj = AudioProjection(
audio_feature_dim=audio_feature_dim,
n_neighbors=(2, 2),
......@@ -309,7 +240,11 @@ class AudioAdapter(nn.Module):
@torch.no_grad()
def forward_audio_proj(self, audio_feat, latent_frame):
if self.cpu_offload:
self.audio_proj.to("cuda")
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe
x = x + self.audio_pe.cuda()
if self.cpu_offload:
self.audio_proj.to("cpu")
return x
......@@ -5,14 +5,20 @@ from lightx2v.utils.envs import *
class SekoAudioEncoderModel:
def __init__(self, model_path, audio_sr):
def __init__(self, model_path, audio_sr, cpu_offload):
self.model_path = model_path
self.audio_sr = audio_sr
self.cpu_offload = cpu_offload
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device("cuda")
self.load()
def load(self):
self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path)
self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
self.audio_feature_encoder.to(self.device)
self.audio_feature_encoder.eval()
self.audio_feature_encoder.to(GET_DTYPE())
......@@ -24,6 +30,10 @@ class SekoAudioEncoderModel:
@torch.no_grad()
def infer(self, audio_segment):
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.audio_feature_encoder.device).to(dtype=GET_DTYPE())
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.cuda().to(dtype=GET_DTYPE())
if self.cpu_offload:
self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
if self.cpu_offload:
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
return audio_feat
......@@ -106,7 +106,6 @@ class WanAudioPreInfer(WanPreInfer):
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
if self.clean_cuda_cache:
del clip_fea
torch.cuda.empty_cache()
context = torch.concat([context_clip, context], dim=0)
......
......@@ -28,7 +28,7 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
x = self.modify_hidden_states(
hidden_states=x,
hidden_states=x.to(self.infer_dtype),
grid_sizes=pre_infer_out.grid_sizes,
ca_block=self.audio_adapter.ca[self.block_idx],
audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
......@@ -87,8 +87,11 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention module
if self.audio_adapter.cpu_offload:
ca_block.to("cuda")
residual = ca_block(audio_encoder_output, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
if self.audio_adapter.cpu_offload:
ca_block.to("cpu")
residual = residual.to(ori_dtype)  # after cross-attention with the audio features, the result is injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
......
......@@ -97,9 +97,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
if self.clean_cuda_cache:
del (
pre_infer_out.grid_sizes,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
)
......@@ -235,9 +233,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
self.phase_params["c_gate_msa"],
)
del (
pre_infer_out.grid_sizes,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
)
......
......@@ -140,7 +140,6 @@ class DefaultRunner(BaseRunner):
if hasattr(self.model.transformer_weights, "clear"):
self.model.transformer_weights.clear()
self.model.pre_weight.clear()
self.model.post_weight.clear()
del self.model
torch.cuda.empty_cache()
gc.collect()
......
......@@ -16,7 +16,7 @@ from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from transformers import AutoFeatureExtractor
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter, rank0_load_state_dict_from_path
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
......@@ -26,7 +26,7 @@ from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image
from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
......@@ -238,14 +238,14 @@ class AudioProcessor:
class WanAudioRunner(WanRunner): # type:ignore
def __init__(self, config):
super().__init__(config)
self._audio_processor = None
self._video_generator = None
self._audio_preprocess = None
self.frame_preprocessor = FramePreprocessor()
def init_scheduler(self):
"""Initialize consistency model scheduler"""
scheduler = ConsistencyModelScheduler(self.config)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.audio_adapter = self.load_audio_adapter()
self.model.set_audio_adapter(self.audio_adapter)
scheduler.set_audio_adapter(self.audio_adapter)
self.model.set_scheduler(scheduler)
......@@ -288,12 +288,25 @@ class WanAudioRunner(WanRunner): # type:ignore
return ref_img
def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
gc.collect()
return clip_encoder_out
def run_vae_encoder(self, img):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
img = rearrange(img, "1 C H W -> 1 C 1 H W")
vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))[0]
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
return vae_encoder_out
@ProfilingContext("Run Encoders")
......@@ -331,6 +344,9 @@ class WanAudioRunner(WanRunner): # type:ignore
last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
prev_frames[:, :, :prev_frame_length] = last_frames
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
_, nframe, height, width = self.model.scheduler.latents.shape
if self.config.model_cls == "wan2.2_audio":
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
......@@ -353,6 +369,11 @@ class WanAudioRunner(WanRunner): # type:ignore
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
return {"prev_latents": prev_latents, "prev_mask": prev_mask}
def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
......@@ -386,7 +407,10 @@ class WanAudioRunner(WanRunner): # type:ignore
torch.manual_seed(self.config.seed)
logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
audio_features = self.audio_encoder.infer(self.segment.audio_array).to(self.model.device)
if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"):
self.audio_encoder = self.load_audio_encoder()
audio_features = self.audio_encoder.infer(self.segment.audio_array)
audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1])
self.inputs["audio_encoder_output"] = audio_features
......@@ -507,10 +531,17 @@ class WanAudioRunner(WanRunner): # type:ignore
return base_model
def load_audio_encoder(self):
model = SekoAudioEncoderModel(os.path.join(self.config["model_path"], "audio_encoder"), self.config["audio_sr"])
audio_encoder_path = os.path.join(self.config["model_path"], "audio_encoder")
audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
return model
def load_audio_adapter(self):
audio_adapter_offload = self.config.get("audio_adapter_cpu_offload", self.config.get("cpu_offload", False))
if audio_adapter_offload:
device = torch.device("cpu")
else:
device = torch.device("cuda")
audio_adapter = AudioAdapter(
attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"],
......@@ -522,7 +553,9 @@ class WanAudioRunner(WanRunner): # type:ignore
mlp_dims=(1024, 1024, 32 * 1024),
quantized=self.config.get("adapter_quantized", False),
quant_scheme=self.config.get("adapter_quant_scheme", None),
cpu_offload=audio_adapter_offload,
)
audio_adapter.to(device)
if self.config.get("adapter_quantized", False):
if self.config.get("adapter_quant_scheme", None) == "fp8":
model_name = "audio_adapter_fp8.safetensors"
......@@ -532,7 +565,9 @@ class WanAudioRunner(WanRunner): # type:ignore
raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
else:
model_name = "audio_adapter.safetensors"
rank0_load_state_dict_from_path(audio_adapter, os.path.join(self.config["model_path"], model_name), strict=False)
weights_dict = load_weights(os.path.join(self.config["model_path"], model_name), cpu_offload=audio_adapter_offload)
audio_adapter.load_state_dict(weights_dict, strict=False)
return audio_adapter.to(dtype=GET_DTYPE())
@ProfilingContext("Load models")
......
......@@ -88,6 +88,8 @@ class WanVaceRunner(WanRunner):
return src_video, src_mask, src_ref_images
def run_vae_encoder(self, frames, ref_images, masks):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
if ref_images is None:
ref_images = [None] * len(frames)
else:
......@@ -115,6 +117,10 @@ class WanVaceRunner(WanRunner):
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
self.latent_shape = list(cat_latents[0].shape)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
return self.get_vae_encoder_output(cat_latents, masks, ref_images)
def get_vae_encoder_output(self, cat_latents, masks, ref_images):
......
......@@ -17,7 +17,11 @@ class ConsistencyModelScheduler(WanScheduler):
def step_pre(self, step_index):
super().step_pre(step_index)
if self.audio_adapter.cpu_offload:
self.audio_adapter.time_embedding.to("cuda")
self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))
if self.audio_adapter.cpu_offload:
self.audio_adapter.time_embedding.to("cpu")
def prepare(self, image_encoder_output=None):
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
......
......@@ -7,6 +7,7 @@ from typing import Optional
import imageio
import imageio_ffmpeg as ffmpeg
import numpy as np
import safetensors
import torch
import torch.distributed as dist
import torchvision
......@@ -267,9 +268,9 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.append(os.path.join(config.model_path, sub, filename))
paths_to_check.insert(0, os.path.join(config.model_path, sub, filename))
else:
paths_to_check.append(os.path.join(config.model_path, subdir, filename))
paths_to_check.insert(0, os.path.join(config.model_path, subdir, filename))
for path in paths_to_check:
if os.path.exists(path):
......@@ -284,10 +285,9 @@ def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["origin
paths_to_check = [model_path]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.append(os.path.join(model_path, sub))
paths_to_check.insert(0, os.path.join(model_path, sub))
else:
paths_to_check.append(os.path.join(model_path, subdir))
paths_to_check.insert(0, os.path.join(model_path, subdir))
for path in paths_to_check:
safetensors_pattern = os.path.join(path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
......@@ -324,10 +324,45 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def load_safetensors(in_path: str):
if os.path.isdir(in_path):
return load_safetensors_from_dir(in_path)
elif os.path.isfile(in_path):
return load_safetensors_from_path(in_path)
else:
raise ValueError(f"{in_path} does not exist")
def load_safetensors_from_path(in_path: str):
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors
def load_safetensors_from_dir(in_dir: str):
tensors = {}
safetensors = os.listdir(in_dir)
safetensors = [f for f in safetensors if f.endswith(".safetensors")]
for f in safetensors:
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
return tensors
def load_pt_safetensors(in_path: str):
ext = os.path.splitext(in_path)[-1]
if ext in (".pt", ".pth", ".tar"):
state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
else:
state_dict = load_safetensors(in_path)
return state_dict
def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
if not dist.is_initialized():
# Single GPU mode
cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
cpu_weight_dict = load_pt_safetensors(checkpoint_path)
for key in list(cpu_weight_dict.keys()):
if remove_key and remove_key in key:
cpu_weight_dict.pop(key)
......@@ -343,7 +378,7 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
cpu_weight_dict = {}
if is_weight_loader:
logger.info(f"Loading weights from {checkpoint_path}")
cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
cpu_weight_dict = load_pt_safetensors(checkpoint_path)
for key in list(cpu_weight_dict.keys()):
if remove_key and remove_key in key:
cpu_weight_dict.pop(key)
......
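Because `load_pt_safetensors` dispatches on the file extension, `load_weights` can now load `.safetensors` checkpoints as well as `.pt`/`.pth`/`.tar` ones, which is what lets the audio runner above swap `rank0_load_state_dict_from_path` for a plain `load_weights` call. A single-GPU usage sketch with a placeholder path:

```python
import os

from lightx2v.utils.utils import load_weights

# Placeholder model directory for illustration.
model_path = "/path/to/Wan2.1-R2V721-Audio-14B-720P"

# Accepts .safetensors as well as .pt/.pth/.tar; with cpu_offload=True the
# tensors stay on the CPU, ready for audio_adapter.load_state_dict(...).
weights_dict = load_weights(os.path.join(model_path, "audio_adapter.safetensors"), cpu_offload=True)
print(f"loaded {len(weights_dict)} tensors on CPU")
```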
......@@ -49,7 +49,7 @@ export ENABLE_GRAPH_MODE=true
echo "==============================================================================="
echo "LightX2V Environment Variables Summary:"
echo "LightX2V Base Environment Variables Summary:"
echo "-------------------------------------------------------------------------------"
echo "lightx2v_path: ${lightx2v_path}"
echo "model_path: ${model_path}"
......
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
lightx2v_path=/mtc/gushiqiao/llmc_workspace/LightX2V
model_path=/data/nvme0/gushiqiao/models/Lightx2v_models/Wan2.1-R2V721-Audio-14B-720P
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=2
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
......
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export DTYPE=FP16
export SENSITIVE_LAYER_DTYPE=FP16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
echo "==============================================================================="
echo "LightX2V Lazyload Environment Variables Summary:"
echo "-------------------------------------------------------------------------------"
echo "lightx2v_path: ${lightx2v_path}"
echo "model_path: ${model_path}"
echo "-------------------------------------------------------------------------------"
echo "Model Inference Data Type: ${DTYPE}"
echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}"
echo "Performance Profiling Debug Mode: ${ENABLE_PROFILING_DEBUG}"
echo "Graph Mode Optimization: ${ENABLE_GRAPH_MODE}"
echo "==============================================================================="
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/offload/disk/wan_i2v_audio_phase_lazy_load_720p.json \
--prompt "The video features a old lady is saying something and knitting a sweater." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/audio_driven/wan_i2v_audio_offload.json \
--prompt "The video features a old lady is saying something and knitting a sweater." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
......@@ -8,12 +8,29 @@ export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export DTYPE=FP16
export SENSITIVE_LAYER_DTYPE=FP16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
echo "==============================================================================="
echo "LightX2V Lazyload Environment Variables Summary:"
echo "-------------------------------------------------------------------------------"
echo "lightx2v_path: ${lightx2v_path}"
echo "model_path: ${model_path}"
echo "-------------------------------------------------------------------------------"
echo "Model Inference Data Type: ${DTYPE}"
echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}"
echo "Performance Profiling Debug Mode: ${ENABLE_PROFILING_DEBUG}"
echo "Graph Mode Optimization: ${ENABLE_GRAPH_MODE}"
echo "==============================================================================="
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json /data/video_gen/lightx2v_latest/lightx2v/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json \
--config_json ${lightx2v_path}/configs/offload/disk/wan_i2v_phase_lazy_load_720p.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
......