Unverified Commit 3c464853 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

update scheduler (#345)

parent 131c8a46
......@@ -75,19 +75,10 @@ class Wan22StepDistillScheduler(WanStepDistillScheduler):
def step_post(self):
flow_pred = self.noise_pred.to(torch.float32)
sigma = self.sigmas[self.step_index].item()
noisy_image_or_video = self.latents.to(torch.float32) - flow_pred * sigma
# self.latent: x_t
if self.step_index < self.boundary_step_index:
# noisy_image_or_video: x_500
alpha, beta = self.calculate_alpha_beta_high(sigma)
noisy_image_or_video = (self.latents.to(torch.float32) - beta * (1 - self.sigma_bound) * flow_pred) / (alpha + beta)
if self.step_index < self.boundary_step_index - 1:
sigma_n = self.sigmas[self.step_index + 1].item()
alpha_n, beta_n = self.calculate_alpha_beta_high(sigma_n)
noisy_image_or_video = (alpha_n + beta_n) * noisy_image_or_video + (1 - self.sigma_bound) * beta_n * flow_pred
else:
# noisy_image_or_video: x_0
noisy_image_or_video = self.latents.to(torch.float32) - flow_pred * sigma
if self.step_index < self.infer_steps - 1:
sigma_n = self.sigmas[self.step_index + 1].item()
noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
if self.step_index < self.infer_steps - 1:
sigma_n = self.sigmas[self.step_index + 1].item()
noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
self.latents = noisy_image_or_video.to(self.latents.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment