Unverified Commit 6dd3871a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix DPM single (#3413)



* Fix DPM single

* add test

* fix one more bug

* Apply suggestions from code review
Co-authored-by: default avatarStAlKeR7779 <stalkek7779@yandex.ru>

---------
Co-authored-by: default avatarStAlKeR7779 <stalkek7779@yandex.ru>
parent 51843fd7
......@@ -21,9 +21,13 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
......@@ -251,7 +255,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [None] * self.config.solver_order
self.sample = None
self.orders = self.get_order_list(num_inference_steps)
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
logger.warn(
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
)
self.register_to_config(lower_order_final=True)
self.order_list = self.get_order_list(num_inference_steps)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
......@@ -597,6 +608,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs[-1] = model_output
order = self.order_list[step_index]
# For img2img denoising might start with order>1 which is not possible
# In this case make sure that the first two steps are both order=1
while self.model_outputs[-order] is None:
order -= 1
# For single-step solvers, we use the initial value at each time with order = 1.
if order == 1:
self.sample = sample
......
......@@ -116,6 +116,22 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
return sample
def test_full_uneven_loop(self):
scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
num_inference_steps = 50
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# make sure that the first t is uneven
for i, t in enumerate(scheduler.timesteps[3:]):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2574) < 1e-3
def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment