Unverified Commit e222246b authored by hlky's avatar hlky Committed by GitHub
Browse files

Fix sigma_last with use_flow_sigmas (#10267)

parent 83709d5a
...@@ -289,6 +289,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -289,6 +289,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = 1.0 - alphas sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy() timesteps = (sigmas * self.config.num_train_timesteps).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
......
...@@ -291,14 +291,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -291,14 +291,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_exponential_sigmas: elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_beta_sigmas: elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_flow_sigmas: elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy() timesteps = (sigmas * self.config.num_train_timesteps).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = ( sigma_max = (
......
...@@ -318,6 +318,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): ...@@ -318,6 +318,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
sigmas = 1.0 - alphas sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy() timesteps = (sigmas * self.config.num_train_timesteps).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
......
...@@ -381,6 +381,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -381,6 +381,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = 1.0 - alphas sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy() timesteps = (sigmas * self.config.num_train_timesteps).copy()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min": if self.config.final_sigmas_type == "sigma_min":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment