Unverified Commit 1b6c7ea7 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[schedulers] create `self.sigmas` during __init__ (#6006)

* fix dpm
* all scheulers
parent b41f809a
...@@ -162,6 +162,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -162,6 +162,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.alpha_t = torch.sqrt(self.alphas_cumprod) self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
......
...@@ -189,6 +189,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -189,6 +189,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.alpha_t = torch.sqrt(self.alphas_cumprod) self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
......
...@@ -184,6 +184,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -184,6 +184,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.alpha_t = torch.sqrt(self.alphas_cumprod) self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
......
...@@ -172,6 +172,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -172,6 +172,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.alpha_t = torch.sqrt(self.alphas_cumprod) self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
......
...@@ -175,6 +175,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -175,6 +175,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.alpha_t = torch.sqrt(self.alphas_cumprod) self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0 self.init_noise_sigma = 1.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment