Unverified Commit 80c00e54 authored by YiYi Xu, committed by GitHub

add `use_karras_sigmas` to `KDPM2DiscreteScheduler` and `KDPM2AncestralDiscreteScheduler` (#5111)




---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 2badddfd
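In practice the new option is enabled when building the scheduler. A minimal usage sketch, assuming a standard Stable Diffusion checkpoint (the model id below is illustrative):

import torch
from diffusers import DiffusionPipeline, KDPM2DiscreteScheduler

# Model id is illustrative; any Stable Diffusion checkpoint should work.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Rebuild the scheduler from the existing config, enabling the new flag.
pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

pipe = pipe.to("cuda")
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]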
@@ -609,6 +609,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
             noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
             sampler_kwargs["noise_sampler"] = noise_sampler
 
+        if "generator" in inspect.signature(self.sampler).parameters:
+            sampler_kwargs["generator"] = generator
+
         latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
 
         if not output_type == "latent":
...
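The pipeline change above forwards `generator` only when the wrapped k-diffusion sampler declares that parameter, probing its signature with `inspect`. A standalone sketch of the pattern (the sampler functions here are hypothetical stand-ins):

import inspect

def sampler_without_generator(model_fn, x, sigmas):
    return x

def sampler_with_generator(model_fn, x, sigmas, generator=None):
    return x

def build_sampler_kwargs(sampler, generator):
    # Pass `generator` only if the sampler's signature declares it, so
    # samplers that do not support seeding are not broken by an extra kwarg.
    kwargs = {}
    if "generator" in inspect.signature(sampler).parameters:
        kwargs["generator"] = generator
    return kwargs

print(build_sampler_kwargs(sampler_without_generator, "gen"))  # -> {}
print(build_sampler_kwargs(sampler_with_generator, "gen"))     # -> {'generator': 'gen'}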
@@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+            the sigmas are determined according to a sequence of noise levels {σi}.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -113,6 +116,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        use_karras_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
@@ -243,9 +247,15 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
            )
 
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
+        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+
+        self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
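When `use_karras_sigmas` is set, the branch above swaps the linearly interpolated sigmas for a Karras-spaced schedule, then maps each new sigma back to a (fractional) training timestep by inverting the log-sigma table; the `_sigma_to_t` helper that performs this is added later in this diff. A numpy sketch of the inversion with toy values (not the scheduler's real betas):

import numpy as np

# Toy log-sigma table over 1000 "training" timesteps.
train_sigmas = np.linspace(0.01, 15.0, 1000)
log_sigmas = np.log(train_sigmas)

def sigma_to_t(sigma, log_sigmas):
    # Same interpolation as the scheduler's _sigma_to_t: find the bracketing
    # table entries in log-space, then linearly interpolate between them.
    log_sigma = np.log(sigma)
    dists = log_sigma - log_sigmas[:, np.newaxis]
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.reshape(np.asarray(sigma).shape)

sigma = np.array([0.5, 7.2])
print(sigma_to_t(sigma, log_sigmas))  # fractional timesteps between 0 and 999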
@@ -269,7 +279,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
 
        timesteps = torch.from_numpy(timesteps).to(device)
-        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
+        sigmas_interpol = sigmas_interpol.cpu()
+        log_sigmas = self.log_sigmas.cpu()
+        timesteps_interpol = np.array(
+            [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
+        )
+        timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
+
        interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
 
        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
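The `torch.stack(...).flatten()` idiom in the hunk above interleaves the interpolated timesteps with the regular ones, so that first- and second-order solver steps alternate. A tiny standalone sketch:

import torch

a = torch.tensor([10, 20, 30])   # e.g. interpolated timesteps
b = torch.tensor([11, 21, 31])   # e.g. regular timesteps

# Stacking on a trailing dim and flattening interleaves elementwise.
interleaved = torch.stack((a[:, None], b[:, None]), dim=-1).flatten()
print(interleaved)  # tensor([10, 11, 20, 21, 30, 31])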
@@ -282,29 +298,44 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self._step_index = None
 
-    def sigma_to_t(self, sigma):
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
-        log_sigma = sigma.log()
+        log_sigma = np.log(sigma)
 
        # get distribution
-        dists = log_sigma - self.log_sigmas[:, None]
+        dists = log_sigma - log_sigmas[:, np.newaxis]
 
        # get sigmas range
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
 
-        low = self.log_sigmas[low_idx]
-        high = self.log_sigmas[high_idx]
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
 
        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
-        w = w.clamp(0, 1)
+        w = np.clip(w, 0, 1)
 
        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
-        t = t.view(sigma.shape)
+        t = t.reshape(sigma.shape)
        return t
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
        @property
        def state_in_first_order(self):
            return self.sample is None
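The newly added `_convert_to_karras` implements the rho-schedule from Karras et al. (2022): sigmas are spaced uniformly in sigma**(1/rho) space (rho = 7) and mapped back, which clusters steps near sigma_min. A quick standalone check with illustrative sigma bounds:

import numpy as np

def convert_to_karras(sigma_min, sigma_max, num_inference_steps, rho=7.0):
    # Uniform ramp in sigma**(1/rho) space, raised back to the rho power.
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(convert_to_karras(0.03, 14.6, 5))
# Decreasing and densely spaced near sigma_min:
# approximately [14.6, 4.8, 1.29, 0.25, 0.03]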
...
@@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+            the sigmas are determined according to a sequence of noise levels {σi}.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -112,6 +115,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        use_karras_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
@@ -243,9 +247,14 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
            )
 
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
+        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+
+        self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -260,7 +269,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        timesteps = torch.from_numpy(timesteps).to(device)
 
        # interpolate timesteps
-        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
+        sigmas_interpol = sigmas_interpol.cpu()
+        log_sigmas = self.log_sigmas.cpu()
+        timesteps_interpol = np.array(
+            [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
+        )
+        timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
        interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
 
        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -273,29 +287,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        self._step_index = None
 
-    def sigma_to_t(self, sigma):
-        # get log sigma
-        log_sigma = sigma.log()
-
-        # get distribution
-        dists = log_sigma - self.log_sigmas[:, None]
-
-        # get sigmas range
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
-        high_idx = low_idx + 1
-
-        low = self.log_sigmas[low_idx]
-        high = self.log_sigmas[high_idx]
-
-        # interpolate sigmas
-        w = (low - log_sigma) / (low - high)
-        w = w.clamp(0, 1)
-
-        # transform interpolation to time range
-        t = (1 - w) * low_idx + w * high_idx
-        t = t.view(sigma.shape)
-        return t
-
        @property
        def state_in_first_order(self):
            return self.sample is None
@@ -318,6 +309,44 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        self._step_index = step_index.item()
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
        def step(
            self,
            model_output: Union[torch.FloatTensor, np.ndarray],
...
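With these changes, both schedulers accept the flag at construction time, and `set_timesteps` then produces the Karras-spaced schedule directly. A quick sketch:

from diffusers import KDPM2DiscreteScheduler

scheduler = KDPM2DiscreteScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)

# Sigmas now follow the Karras spacing (decreasing, dense near zero),
# and timesteps were recomputed from those sigmas via _sigma_to_t.
print(scheduler.sigmas)
print(scheduler.timesteps)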