Unverified Commit 6e8e1ed7 authored by Nipun Jindal, committed by GitHub

[2905]: Add Karras pattern to discrete euler (#2956)



* [2905]: Add Karras pattern to discrete euler

* [2905]: Add Karras pattern to discrete euler

* Review comments

* Review comments

* Review comments

* Review comments

---------
Co-authored-by: njindal <njindal@adobe.com>
parent 37b359b2
...@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        interpolation_type (`str`, default `"linear"`, optional):
            interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            [`"linear"`, `"log_linear"`].
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise schedule
            during sampling. If `True`, the sigmas are determined according to the sequence of noise levels {σi}
            defined in Equation (5) of https://arxiv.org/pdf/2206.00364.pdf.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
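
For reference, the docstring above points at Equation (5) of Karras et al. (2022) (https://arxiv.org/pdf/2206.00364.pdf). With $n$ sampling steps and $i \in \{0, \dots, n-1\}$, the noise levels it prescribes are

$$\sigma_i = \left( \sigma_{\max}^{1/\rho} + \frac{i}{n-1}\left( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \right) \right)^{\rho}, \qquad \rho = 7,$$

which is exactly what the `_convert_to_karras` helper added further down computes.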
...@@ -118,6 +122,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -149,6 +154,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
...@@ -187,6 +193,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
...@@ -198,6 +205,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                " 'linear' or 'log_linear'"
            )

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        if str(device).startswith("mps"):
...@@ -206,6 +217,43 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def step(
        self,
        model_output: torch.FloatTensor,
...
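
Before the test file, a quick illustration of how the two new helpers interact. The following is a minimal, self-contained NumPy sketch (illustration only, not part of this commit; all names are local to the sketch, and the beta schedule is assumed linear): it builds the training sigmas the same way `set_timesteps` does, applies the Equation (5) schedule, and then maps each Karras sigma back to a fractional training timestep by piecewise-linear interpolation in log-sigma space, mirroring `_convert_to_karras` and `_sigma_to_t`.

```python
import numpy as np

num_train_timesteps = 1000  # training schedule length, as in the scheduler defaults
num_inference_steps = 10
rho = 7.0  # value used in the Karras et al. (2022) paper

# Training sigmas for a linear beta schedule, computed as in the scheduler.
betas = np.linspace(0.0001, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5  # increasing with t
log_sigmas = np.log(sigmas)

# Equation (5): interpolate between sigma_max and sigma_min in sigma^(1/rho) space.
sigma_max, sigma_min = sigmas[-1], sigmas[0]
ramp = np.linspace(0, 1, num_inference_steps)
karras_sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

def sigma_to_t(sigma):
    """Fractional training timestep for a given sigma, found by piecewise-linear
    interpolation in log-sigma space (same idea as _sigma_to_t above)."""
    log_sigma = np.log(sigma)
    # index of the last training sigma that is <= sigma (log_sigmas is increasing)
    low_idx = int(np.clip(np.sum(log_sigma >= log_sigmas) - 1, 0, len(log_sigmas) - 2))
    low, high = log_sigmas[low_idx], log_sigmas[low_idx + 1]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    return (1 - w) * low_idx + w * (low_idx + 1)

timesteps = np.array([sigma_to_t(s) for s in karras_sigmas])
print(karras_sigmas)  # decreasing from ~sigma_max to ~sigma_min
print(timesteps)      # corresponding fractional timesteps, ~999 down to 0
```

Note that the resulting timesteps are fractional rather than integer, which is why the diff recomputes `timesteps` from the Karras sigmas instead of reusing the linearly spaced ones.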
...@@ -117,3 +117,30 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
        assert abs(result_sum.item() - 10.0807) < 1e-2
        assert abs(result_mean.item() - 0.0131) < 1e-3

    def test_full_loop_device_karras_sigmas(self):
        # construct the scheduler with the Karras noise schedule enabled
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        # run the full denoising loop and check the result against reference values
        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)
            model_output = model(sample, t)
            output = scheduler.step(model_output, t, sample, generator=generator)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 124.52299499511719) < 1e-2
        assert abs(result_mean.item() - 0.16213932633399963) < 1e-3
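
Finally, for anyone wanting to try the new flag end to end, a usage sketch (hypothetical example, not part of this commit; the checkpoint name below is only a placeholder for any Stable Diffusion checkpoint):

```python
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

# Swap the pipeline's scheduler for Euler with the Karras noise schedule enabled.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=20).images[0]
image.save("astronaut.png")
```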