"vscode:/vscode.git/clone" did not exist on "76edff62f58dbaed25f3b047ecf56fe5fec8e7c1"
Commit a4818f0f authored by helloyongyang's avatar helloyongyang
Browse files

[Feature]: support changing_resolution

parent 3e22e549
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"changing_resolution": true
}
@@ -7,6 +7,7 @@ from PIL import Image
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v.models.schedulers.wan.changing_resolution.scheduler import WanScheduler4ChangingResolution
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerTeaCaching,
     WanSchedulerTaylorCaching,
@@ -119,18 +120,21 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder
     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
-            scheduler = WanScheduler(self.config)
-        elif self.config.feature_caching == "Tea":
-            scheduler = WanSchedulerTeaCaching(self.config)
-        elif self.config.feature_caching == "TaylorSeer":
-            scheduler = WanSchedulerTaylorCaching(self.config)
-        elif self.config.feature_caching == "Ada":
-            scheduler = WanSchedulerAdaCaching(self.config)
-        elif self.config.feature_caching == "Custom":
-            scheduler = WanSchedulerCustomCaching(self.config)
-        else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+        if self.config.get("changing_resolution", False):
+            scheduler = WanScheduler4ChangingResolution(self.config)
+        else:
+            if self.config.feature_caching == "NoCaching":
+                scheduler = WanScheduler(self.config)
+            elif self.config.feature_caching == "Tea":
+                scheduler = WanSchedulerTeaCaching(self.config)
+            elif self.config.feature_caching == "TaylorSeer":
+                scheduler = WanSchedulerTaylorCaching(self.config)
+            elif self.config.feature_caching == "Ada":
+                scheduler = WanSchedulerAdaCaching(self.config)
+            elif self.config.feature_caching == "Custom":
+                scheduler = WanSchedulerCustomCaching(self.config)
+            else:
+                raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
     def run_text_encoder(self, text, img):
......
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class WanScheduler4ChangingResolution(WanScheduler):
    """WanScheduler variant that starts at a reduced resolution and switches to the target one.

    Latents are first sampled with their last two dimensions scaled by
    ``resolution_rate``. When ``step_index`` reaches
    ``changing_resolution_steps``, the current clean-sample estimate is
    upsampled to ``config.target_shape``, re-noised with noise drawn at the
    original resolution, and the timestep schedule is rebuilt with a larger
    shift.

    Optional config keys (both fall back to the previous hard-coded values):
        resolution_rate: scale factor for the initial latent size (default 0.75).
        changing_resolution_steps: step index at which the switch happens (default 25).
    """

    def __init__(self, config):
        super().__init__(config)
        # Allow overriding via config; defaults keep the earlier fixed behavior.
        self.resolution_rate = config.get("resolution_rate", 0.75)
        self.changing_resolution_steps = config.get("changing_resolution_steps", 25)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        """Sample the reduced-resolution latents and the full-resolution noise.

        Args:
            target_shape: four-entry shape of the final latents; entries 2 and 3
                are the ones scaled down (presumably H and W, given the
                (C,T,H,W) layout noted in ``step_post_upsample``).
            dtype: dtype of the sampled tensors.
        """
        # Round the scaled sizes down to an even number.
        reduced_dim2 = int(target_shape[2] * self.resolution_rate) // 2 * 2
        reduced_dim3 = int(target_shape[3] * self.resolution_rate) // 2 * 2

        # Order matters: both draws consume the same generator, so the
        # reduced-resolution latents must be sampled first.
        self.latents = torch.randn(
            target_shape[0],
            target_shape[1],
            reduced_dim2,
            reduced_dim3,
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )
        # Noise at the original resolution, used later to re-noise the upsampled sample.
        self.noise_original_resolution = torch.randn(
            target_shape[0],
            target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )

    def step_post(self):
        """Run the resolution switch at the configured step, otherwise the parent update."""
        if self.step_index == self.changing_resolution_steps:
            self.step_post_upsample()
        else:
            super().step_post()

    def step_post_upsample(self):
        """Replace the latents with an upsampled, re-noised sample and rebuild the timesteps."""
        # 1. Estimate the clean sample from the current latents and the model prediction.
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma_t = self.sigmas[self.step_index]
        x0_pred = sample - sigma_t * model_output
        denoised_sample = x0_pred.to(sample.dtype)

        # 2. Upsample the clean estimate to the target shape.
        denoised_sample_5d = denoised_sample.unsqueeze(0)  # (C,T,H,W) -> (1,C,T,H,W)
        target_size = (
            self.config.target_shape[1],
            self.config.target_shape[2],
            self.config.target_shape[3],
        )
        clean_noise = torch.nn.functional.interpolate(
            denoised_sample_5d,
            size=target_size,
            mode="trilinear",
        )
        clean_noise = clean_noise.squeeze(0)  # (1,C,T,H,W) -> (C,T,H,W)

        # 3. Re-noise the upsampled estimate with the full-resolution noise.
        # NOTE(review): add_noise currently ignores the timestep passed here and
        # uses self.sigmas[self.step_index]; confirm that is the intended noise level.
        noisy_sample = self.add_noise(
            clean_noise,
            self.noise_original_resolution,
            self.timesteps[self.step_index + 1],
        )

        # 4. Update the latents.
        self.latents = noisy_sample

        # NOTE: disabling the corrector for steps 24..37 was considered here; maybe not needed.

        # 5. Rebuild the timesteps with shift + 2 (more aggressive denoising).
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)

    def add_noise(self, original_samples, noise, timesteps):
        """Blend ``original_samples`` with ``noise`` at the current step's sigma.

        Args:
            original_samples: clean sample to be noised.
            noise: noise tensor combined with the sample.
            timesteps: accepted but not used; the noise level is read from
                ``self.sigmas[self.step_index]``.

        Returns:
            ``alpha_t * original_samples + sigma_t * noise``.
        """
        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
#!/bin/bash
# Launch lightx2v text-to-video inference with the changing_resolution config.

# Set these two paths first; the script exits with an error if either is empty.
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

# Make the lightx2v package importable by `python -m lightx2v.infer`.
export PYTHONPATH="${lightx2v_path}:${PYTHONPATH}"

export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

# Path arguments are quoted so directories containing spaces or glob
# characters are passed through as single arguments.
python -m lightx2v.infer \
    --model_cls wan2.1 \
    --task t2v \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/changing_resolution/wan_t2v.json" \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path "${lightx2v_path}/save_results/output_lightx2v_wan_t2v_changing_resolution.mp4"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment