Commit 75eac23c authored by wangshankun

Add: Audio CM Scheduler

parent d86b6917
@@ -6,13 +6,12 @@
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
-    "self_attn_1_type": "radial_attn",
+    "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
-    "cpu_offload": false,
-    "use_tiling_vae": true
+    "cpu_offload": false
 }
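The attention backend for the first self-attention block changes from radial_attn to flash_attn3, and the use_tiling_vae flag is dropped. Note that "enable_cfg": false is consistent with "sample_guide_scale": 1: at guide scale 1 the classifier-free-guidance combination collapses to the conditional prediction, so the unconditional forward pass would be wasted work. A worked illustration (not code from this repo):

```python
import torch

def cfg_combine(cond, uncond, s):
    # Standard classifier-free guidance mix; at s == 1 this returns cond
    # exactly, which is why the config disables CFG at guide scale 1.
    return uncond + s * (cond - uncond)

cond, uncond = torch.randn(4), torch.randn(4)
assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)
```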
@@ -24,8 +24,6 @@ class WanAudioPreInfer(WanPreInfer):
         self.text_len = config["text_len"]

     def infer(self, weights, inputs, positive):
-        ltnt_frames = self.scheduler.latents.size(1)
-
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
         prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
...
@@ -19,7 +19,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
-from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
 from loguru import logger
 import torch.distributed as dist
@@ -327,7 +327,7 @@ class WanAudioRunner(WanRunner):
         super().__init__(config)

     def init_scheduler(self):
-        scheduler = EulerSchedulerTimestepFix(self.config)
+        scheduler = ConsistencyModelScheduler(self.config)
         self.model.set_scheduler(scheduler)

     def load_audio_models(self):
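init_scheduler now builds the new ConsistencyModelScheduler in place of the Euler scheduler. The sketch below contrasts the two update rules under the rectified-flow convention x_t = (1 - sigma) * x0 + sigma * eps with a velocity-predicting model, the usual formulation for Wan-style schedulers; the function names and the re-noising policy are illustrative assumptions, not this repo's step_post:

```python
import torch

def euler_step(sample, velocity, sigma, sigma_next):
    # Euler step on the probability-flow ODE: move along the predicted
    # velocity from the current noise level to the next one.
    return sample + (sigma_next - sigma) * velocity

def consistency_step(sample, velocity, sigma, sigma_next):
    # Consistency-model step: jump to the clean-sample estimate
    # (x0 = x_t - sigma * v under rectified flow), then re-noise to the
    # next level with fresh Gaussian noise.
    x0 = sample - sigma * velocity
    return (1.0 - sigma_next) * x0 + sigma_next * torch.randn_like(sample)
```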
@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner):
         last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
         last_frames = last_frames.cpu().detach().numpy()
-        last_frames = add_noise_to_frames(last_frames)
+        last_frames = add_noise_to_frames(last_frames)  # mean:-3.0 std:0.5
         last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
         last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
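The bodies of add_noise_to_frames and add_mask_to_frames are outside this diff; from the inline comments (mean:-3.0 std:0.5, mask 0.10), a plausible reading is a log-normally sampled noise strength plus random pixel dropout on the overlap frames, so the next chunk cannot copy its conditioning verbatim. A hypothetical reconstruction, not the repo's actual helpers:

```python
import numpy as np

def add_noise_to_frames(frames, mean=-3.0, std=0.5):
    # Hypothetical body: "mean:-3.0 std:0.5" reads like a log-normal noise
    # strength, sigma = exp(N(-3.0, 0.5)) ~ 0.05, applied as additive
    # Gaussian noise to the conditioning frames.
    sigma = np.exp(np.random.normal(mean, std))
    return frames + sigma * np.random.randn(*frames.shape).astype(frames.dtype)

def add_mask_to_frames(frames, mask_rate=0.1):
    # Hypothetical body: zero out roughly mask_rate of the pixels at random.
    keep = (np.random.rand(*frames.shape) >= mask_rate).astype(frames.dtype)
    return frames * keep
```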
@@ -583,7 +583,7 @@ class WanAudioRunner(WanRunner):
         latents = self.model.scheduler.latents
         generator = self.model.scheduler.generator
         gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-        gen_video = torch.clamp(gen_video, -1, 1)
+        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
         start_frame = 0 if idx == 0 else prev_frame_length
         start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
...
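The added .to(torch.float) keeps the clamped video in float32 for the CPU-side post-processing. For the stitching offsets, a worked example with assumed values (neither audio_sr nor target_fps appears in this hunk; Wan models commonly render at 16 fps):

```python
# Overlap handling when chunks are stitched: every chunk after the first
# drops its first prev_frame_length video frames and the matching audio span.
idx, prev_frame_length = 1, 5                 # assumed values
audio_sr, target_fps = 16000, 16              # assumed values

start_frame = 0 if idx == 0 else prev_frame_length
start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
print(start_frame, start_audio_frame)         # 5 6000
```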
@@ -34,7 +34,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
         self.infer_steps = self.config.infer_steps
         self.target_video_length = self.config.target_video_length
         self.sample_shift = self.config.sample_shift
-        self.shift = 1
         self.num_train_timesteps = 1000
         self.step_index = None
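Deleting self.shift = 1 stops the scheduler from clobbering the configured sample_shift (5 in the config above). For reference, the time-shifted sigma schedule commonly used by SD3/Wan-style flow-matching samplers looks like the sketch below; the repo's actual set_timesteps is outside this diff:

```python
import torch

def shifted_sigmas(infer_steps, shift=5.0, num_train_timesteps=1000):
    # shift > 1 concentrates sampling steps at high noise levels, which
    # matters most at low step counts.
    sigmas = torch.linspace(1.0, 1.0 / num_train_timesteps, infer_steps)
    sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
    return sigmas, sigmas * num_train_timesteps  # (sigmas, timesteps)
```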
@@ -94,7 +93,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
 class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
     def step_post(self):
-        logger.info(f"Step index: {self.step_index}, self.timestep: {self.timesteps[self.step_index]}")
         model_output = self.noise_pred.to(torch.float32)
         sample = self.latents.to(torch.float32)
         sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
...
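The removed logger.info line drops a per-step log from the sampling hot loop. unsqueeze_to_ndim is used but not defined in this hunk; it presumably pads the per-step sigma with trailing singleton dims so it broadcasts against the latent tensor. A minimal sketch of such a helper (a guess at the contract, not the repo's definition):

```python
import torch

def unsqueeze_to_ndim(t: torch.Tensor, ndim: int) -> torch.Tensor:
    # Append trailing singleton dims until t matches the target rank.
    while t.ndim < ndim:
        t = t.unsqueeze(-1)
    return t

print(unsqueeze_to_ndim(torch.tensor(0.7), 4).shape)  # torch.Size([1, 1, 1, 1])
```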