Commit f2e1def0 authored by Zhuguanyu Wu's avatar Zhuguanyu Wu Committed by GitHub
Browse files

Update scheduler.py for wan22_moe_distill (#274)

parent 6c2f36de
...@@ -71,9 +71,7 @@ class Wan22StepDistillScheduler(WanStepDistillScheduler): ...@@ -71,9 +71,7 @@ class Wan22StepDistillScheduler(WanStepDistillScheduler):
flow_pred = self.noise_pred.to(torch.float32) flow_pred = self.noise_pred.to(torch.float32)
sigma = self.sigmas[self.step_index].item() sigma = self.sigmas[self.step_index].item()
noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
if self.step_index < self.boundary_step_index:
noisy_image_or_video = noisy_image_or_video / self.sigma_boundary
if self.step_index < self.infer_steps - 1: if self.step_index < self.infer_steps - 1:
sigma = self.sigmas[self.step_index + 1].item() sigma_n = self.sigmas[self.step_index + 1].item()
noisy_image_or_video = self.add_noise(noisy_image_or_video, torch.randn_like(noisy_image_or_video), self.sigmas[self.step_index + 1].item()) noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
self.latents = noisy_image_or_video.to(self.latents.dtype) self.latents = noisy_image_or_video.to(self.latents.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment