Commit cb359e19 authored by Yang Yong (雍洋), committed by GitHub

Support fixed_shape resize for seko & update scheduler for seko (#270)

parent af5105c7
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 12,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "fixed_shape",
    "fixed_shape": [240, 320],
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses"
    },
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
    },
    "adapter_quantized": true,
    "adapter_quant_scheme": "fp8",
    "t5_quantized": true,
    "t5_quant_scheme": "fp8"
}
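
With resize_mode set to "fixed_shape", every reference image is center-cropped and resized to exactly 240x320 (height x width; see the fixed_shape_resize helper in the diff below), so the token grid is known before inference. A minimal sketch of the resulting shape arithmetic, assuming this JSON is the seko_talk_11_fp8_dist_fixed_shape.json referenced by the launch script at the bottom, and assuming the usual Wan2.x strides (spatial VAE stride 8, temporal stride 4, patch size 2); the runner's self.config.vae_stride / patch_size hold the authoritative values:

import json

# Assumed file name (taken from the launch script below); adjust to your local path.
with open("configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json") as f:
    cfg = json.load(f)

h, w = cfg["fixed_shape"]                         # 240, 320 (height, width)
spatial_stride, patch, temporal_stride = 8, 2, 4  # assumed Wan2.x defaults

latent_h, latent_w = h // spatial_stride, w // spatial_stride             # 30 x 40
patched_h, patched_w = latent_h // patch, latent_w // patch               # 15 x 20
latent_frames = (cfg["target_video_length"] - 1) // temporal_stride + 1   # 81 -> 21

print(f"patch grid {patched_h}x{patched_w}, {latent_frames} latent frames")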
@@ -20,7 +20,7 @@ from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudio
 from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import WanRunner
-from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
@@ -80,8 +80,34 @@ def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
     return resized_frames


-def resize_image(img, resize_mode="adaptive", fixed_area=None):
-    assert resize_mode in ["adaptive", "keep_ratio_fixed_area", "fixed_min_area", "fixed_max_area"]
+def fixed_shape_resize(img, target_height, target_width):
+    orig_height, orig_width = img.shape[-2:]
+    target_ratio = target_height / target_width
+    orig_ratio = orig_height / orig_width
+    if orig_ratio > target_ratio:
+        crop_width = orig_width
+        crop_height = int(crop_width * target_ratio)
+    else:
+        crop_height = orig_height
+        crop_width = int(crop_height / target_ratio)
+    cropped_img = TF.center_crop(img, [crop_height, crop_width])
+    resized_img = TF.resize(cropped_img, [target_height, target_width], antialias=True)
+    h, w = resized_img.shape[-2:]
+    return resized_img, h, w
+
+
+def resize_image(img, resize_mode="adaptive", fixed_area=None, fixed_shape=None):
+    assert resize_mode in ["adaptive", "keep_ratio_fixed_area", "fixed_min_area", "fixed_max_area", "fixed_shape"]
+    if resize_mode == "fixed_shape":
+        assert fixed_shape is not None
+        logger.info(f"[wan_audio] fixed_shape_resize fixed_height: {fixed_shape[0]}, fixed_width: {fixed_shape[1]}")
+        return fixed_shape_resize(img, fixed_shape[0], fixed_shape[1])
     bucket_config = {
         0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
@@ -261,7 +287,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def init_scheduler(self):
         """Initialize consistency model scheduler"""
-        scheduler = ConsistencyModelScheduler(self.config)
+        scheduler = EulerScheduler(self.config)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.audio_adapter = self.load_audio_adapter()
             self.model.set_audio_adapter(self.audio_adapter)
@@ -289,7 +315,7 @@
         ref_img = Image.open(img_path).convert("RGB")
         ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
-        ref_img, h, w = resize_image(ref_img, resize_mode=self.config.get("resize_mode", "adaptive"), fixed_area=self.config.get("fixed_area", None))
+        ref_img, h, w = resize_image(ref_img, resize_mode=self.config.get("resize_mode", "adaptive"), fixed_area=self.config.get("fixed_area", None), fixed_shape=self.config.get("fixed_shape", None))
         logger.info(f"[wan_audio] resize_image target_h: {h}, target_w: {w}")
         patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
         patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
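
For reference, a self-contained sketch of what the new fixed_shape_resize helper does, with the crop logic commented; the dummy 720x1280 input is only an illustration:

import torch
import torchvision.transforms.functional as TF

def fixed_shape_resize(img, target_height, target_width):
    orig_height, orig_width = img.shape[-2:]
    target_ratio = target_height / target_width
    orig_ratio = orig_height / orig_width
    if orig_ratio > target_ratio:
        # input is taller (relative to width) than the target: keep full width, crop height
        crop_width = orig_width
        crop_height = int(crop_width * target_ratio)
    else:
        # input is wider than the target: keep full height, crop width
        crop_height = orig_height
        crop_width = int(crop_height / target_ratio)
    cropped = TF.center_crop(img, [crop_height, crop_width])
    resized = TF.resize(cropped, [target_height, target_width], antialias=True)
    return resized

img = torch.rand(1, 3, 720, 1280)        # landscape reference image
out = fixed_shape_resize(img, 240, 320)  # target ratio 0.75 -> center-crop to 720x960 first
print(out.shape)                         # torch.Size([1, 3, 240, 320])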
@@ -10,7 +10,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.utils import masks_like


-class ConsistencyModelScheduler(WanScheduler):
+class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
@@ -89,8 +89,7 @@ class ConsistencyModelScheduler(WanScheduler):
         sample = self.latents.to(torch.float32)
         sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
         sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
-        x0 = sample - model_output * sigma
-        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
+        x_t_next = sample + (sigma_next - sigma) * model_output
         self.latents = x_t_next
         if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
             self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
@@ -110,3 +109,16 @@ class ConsistencyModelScheduler(WanScheduler):
         if in_tensor.ndim < tgt_n_dim:
             in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
         return in_tensor
+
+
+class ConsistencyModelScheduler(EulerScheduler):
+    def step_post(self):
+        model_output = self.noise_pred.to(torch.float32)
+        sample = self.latents.to(torch.float32)
+        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
+        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
+        x0 = sample - model_output * sigma
+        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
+        self.latents = x_t_next
+        if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
+            self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
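
The renamed EulerScheduler now takes a deterministic first-order Euler step along the flow ODE, x_next = x + (sigma_next - sigma) * v, while the previous stochastic update (predict x0, then re-noise it to the next sigma level with fresh Gaussian noise) is preserved in the ConsistencyModelScheduler subclass. A toy comparison of the two updates, assuming the rectified-flow convention x_t = (1 - sigma) * x0 + sigma * eps and an exact velocity prediction v = eps - x0 (these conventions are inferred from the code above, not stated in the commit):

import torch

torch.manual_seed(0)
x0 = torch.randn(4)           # "clean" latent
eps = torch.randn(4)          # noise already baked into the current sample
sigma, sigma_next = 0.8, 0.6

x_t = (1 - sigma) * x0 + sigma * eps
v = eps - x0                  # exact velocity prediction

# EulerScheduler.step_post: deterministic Euler step, keeps the noise implied by x_t
euler_next = x_t + (sigma_next - sigma) * v
assert torch.allclose(euler_next, (1 - sigma_next) * x0 + sigma_next * eps)

# ConsistencyModelScheduler.step_post: reconstruct x0, then re-noise with *fresh* noise
x0_pred = x_t - v * sigma
cm_next = (1 - sigma_next) * x0_pred + sigma_next * torch.randn(4)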
#!/bin/bash
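# Fill in these two paths: the LightX2V repository root and the model checkpoint directory.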
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1,2,3
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
torchrun --nproc-per-node 4 -m lightx2v.infer \
--model_cls seko_talk \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_11_fp8_dist_fixed_shape.json \
    --prompt "The video features an old lady saying something and knitting a sweater." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4
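
Note: the config's seq_p_size of 4 presumably has to match the number of ranks started here (--nproc-per-node 4), since the Ulysses sequence-parallel group is expected to span all launched processes.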