"vscode:/vscode.git/clone" did not exist on "cdc56ef6c1c6f359de87c5f78a66316723557d5d"
Commit 75eac23c authored by wangshankun

Add: Audio CM Scheduler

parent d86b6917
......@@ -6,13 +6,12 @@
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "radial_attn",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale":1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_tiling_vae": true
"cpu_offload": false
}
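The hunk above switches the first self-attention backend from radial_attn to flash_attn3 and drops the use_tiling_vae flag. For context, a minimal sketch of how a JSON config like this might be loaded and sanity-checked before inference; the file name, the helper, and the torch_sdpa entry are assumptions and not part of this commit.

```python
import json

# Assumed set of backend names; only flash_attn3 and radial_attn appear in this diff.
ALLOWED_ATTN = {"flash_attn3", "radial_attn", "torch_sdpa"}

def load_config(path="wan_audio_config.json"):
    # Load the inference config and verify the requested attention backends.
    with open(path) as f:
        cfg = json.load(f)
    for key in ("self_attn_1_type", "cross_attn_1_type", "cross_attn_2_type"):
        if cfg.get(key) not in ALLOWED_ATTN:
            raise ValueError(f"unsupported attention backend for {key}: {cfg.get(key)!r}")
    return cfg
```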
......@@ -24,8 +24,6 @@ class WanAudioPreInfer(WanPreInfer):
self.text_len = config["text_len"]
def infer(self, weights, inputs, positive):
ltnt_frames = self.scheduler.latents.size(1)
prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
......
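The hunk truncates before the rest of infer(). As a rough sketch, under the assumption that the previous-clip latents and mask shown above are blended into the current latents for streaming continuation (this is illustrative, not the repository's actual infer() body):

```python
import torch

def blend_prev_latents(latents, prev_latents, prev_mask):
    # latents:      noisy latents for the current clip, e.g. [C, T, H, W]
    # prev_latents: latents carried over from the previous clip, unsqueezed to [1, C, T, H, W]
    # prev_mask:    1 where the previous clip's content should be kept, broadcastable to latents
    prev_latents = prev_latents.squeeze(0).to(dtype=latents.dtype)
    return prev_mask * prev_latents + (1 - prev_mask) * latents
```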
......@@ -19,7 +19,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
-from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
+from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
from loguru import logger
import torch.distributed as dist
......@@ -327,7 +327,7 @@ class WanAudioRunner(WanRunner):
super().__init__(config)
def init_scheduler(self):
-scheduler = EulerSchedulerTimestepFix(self.config)
+scheduler = ConsistencyModelScheduler(self.config)
self.model.set_scheduler(scheduler)
def load_audio_models(self):
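The commit hard-codes the new ConsistencyModelScheduler in init_scheduler(). A small sketch of how the choice could instead be driven by config; the scheduler_type key and the free-function form are hypothetical, only the two scheduler classes and the import come from the diff.

```python
from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler

def init_scheduler(runner):
    # Hypothetical: pick the scheduler from the runner's config instead of hard-coding it.
    scheduler_cls = {
        "euler": EulerSchedulerTimestepFix,
        "consistency": ConsistencyModelScheduler,
    }[runner.config.get("scheduler_type", "consistency")]
    runner.model.set_scheduler(scheduler_cls(runner.config))
```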
......@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner):
last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
last_frames = last_frames.cpu().detach().numpy()
-last_frames = add_noise_to_frames(last_frames)
+last_frames = add_noise_to_frames(last_frames) # mean:-3.0 std:0.5
last_frames = add_mask_to_frames(last_frames, mask_rate=0.1) # mask 0.10
last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
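add_noise_to_frames and add_mask_to_frames are defined elsewhere in the repository and their bodies are not shown in this diff. The sketches below are assumptions guided only by the inline comments ("mean:-3.0 std:0.5" and "mask 0.10"); the actual parameterization may differ.

```python
import numpy as np

def add_noise_to_frames(frames, log_mean=-3.0, log_std=0.5):
    # Assumed behavior: draw a noise strength log-normally (per the diff comment),
    # then perturb the carried-over frames with Gaussian noise so the next clip
    # does not lock onto them exactly.
    sigma = float(np.exp(np.random.normal(log_mean, log_std)))
    noise = np.random.randn(*frames.shape).astype(frames.dtype)
    return frames + sigma * noise

def add_mask_to_frames(frames, mask_rate=0.1):
    # Assumed behavior: randomly zero out a fraction of pixels (mask_rate = 0.10).
    keep = (np.random.rand(*frames.shape) >= mask_rate).astype(frames.dtype)
    return frames * keep
```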
......@@ -583,7 +583,7 @@ class WanAudioRunner(WanRunner):
latents = self.model.scheduler.latents
generator = self.model.scheduler.generator
gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-gen_video = torch.clamp(gen_video, -1, 1)
+gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
start_frame = 0 if idx == 0 else prev_frame_length
start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
......
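After the clamp-and-cast change above, a typical post-processing step maps the [-1, 1] float video to uint8 frames; that step is assumed here, not shown in the diff. The offset helper simply restates the stitching arithmetic visible in the hunk.

```python
import torch

def to_uint8_video(gen_video: torch.Tensor) -> torch.Tensor:
    # gen_video is clamped to [-1, 1] and cast to float in the diff above;
    # map it to [0, 255] uint8 for writing frames (layout assumed [B, C, T, H, W]).
    return ((gen_video + 1.0) * 127.5).round().clamp(0, 255).to(torch.uint8)

def clip_offsets(idx, prev_frame_length, audio_sr, target_fps):
    # For every clip after the first, skip the video frames that overlap the
    # previous clip and the matching span of audio samples (from the diff).
    start_frame = 0 if idx == 0 else prev_frame_length
    start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
    return start_frame, start_audio_frame
```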
......@@ -34,7 +34,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.shift = 1
self.num_train_timesteps = 1000
self.step_index = None
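Given the attributes listed above (infer_steps, sample_shift, num_train_timesteps), a shifted sigma schedule is commonly built as sketched below. The shift formula is the standard flow-matching warp and is an assumption about this scheduler, not code taken from the file.

```python
import torch

def build_shifted_sigmas(infer_steps, sample_shift, num_train_timesteps=1000):
    # Uniform sigmas from 1 down toward 0, then the usual flow-matching "shift"
    # warp: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
    sigmas = torch.linspace(1.0, 1.0 / num_train_timesteps, infer_steps)
    sigmas = sample_shift * sigmas / (1.0 + (sample_shift - 1.0) * sigmas)
    timesteps = sigmas * num_train_timesteps
    return sigmas, timesteps
```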
......@@ -94,7 +93,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
def step_post(self):
logger.info(f"Step index: {self.step_index}, self.timestep: {self.timesteps[self.step_index]}")
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
......
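The hunk cuts off before the actual update in step_post(). Below is a hedged sketch of a typical consistency-model style step under a flow-matching (velocity) parameterization; it matches the variables shown (model_output, sample, sigma) but is an assumption about the exact method, not the file's code.

```python
import torch

def cm_step(sample, model_output, sigma, sigma_next, generator=None):
    # Flow-matching convention: the network predicts a velocity v, so the clean
    # estimate is x0 = x_t - sigma * v. A consistency-model style step then
    # re-noises x0 to the next sigma level with fresh Gaussian noise.
    x0 = sample - sigma * model_output
    if sigma_next > 0:
        noise = torch.randn(sample.shape, generator=generator,
                            device=sample.device, dtype=sample.dtype)
        return x0 + sigma_next * noise
    return x0  # final step: return the clean estimate
```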