# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, logging
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
    """
    `DPMSolverSinglestepScheduler` is a fast dedicated high-order solver for diffusion ODEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of the training noise schedule.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value of the training noise schedule.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
            algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
            implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. `dpmsolver`
            is deprecated, and `deis` is accepted as an alias for `dpmsolver++`. It is recommended to use `dpmsolver++`
            with `solver_order=2` for guided sampling like in Stable Diffusion.
        solver_type (`str`, defaults to `midpoint`):
            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
            `logrho`, `bh1` and `bh2` are accepted and mapped to `midpoint`.
        lower_order_final (`bool`, defaults to `False`):
            Whether to use lower-order solvers in the final steps. This can stabilize the sampling of DPMSolver for
            steps < 15, especially for steps <= 10. `set_timesteps` switches it on when the number of inference steps
            is not a multiple of `solver_order` or when `final_sigmas_type="zero"`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0; this
            is only supported with `algorithm_type="dpmsolver++"`.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
    ):
        if algorithm_type == "dpmsolver":
            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Use `dpmsolver++` instead"
            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)

        # Training noise schedule (betas), from explicit values or one of the named schedules.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        # Half log-SNR; used by `set_timesteps` to apply `lambda_min_clipped`.
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver; legacy aliases are rewritten in the registered config
        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        # Check the *registered* algorithm type: a `deis` argument has just been remapped to `dpmsolver++`, so
        # testing the raw argument would wrongly reject it under the default `final_sigmas_type="zero"`.
        if self.config.algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.sample = None
        self.order_list = self.get_order_list(num_train_timesteps)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def get_order_list(self, num_inference_steps: int) -> List[int]:
        """
        Computes the solver order at each time step.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.

        Returns:
            `List[int]`: The solver order for each step; orders cycle through `1, ..., solver_order`, with the tail
            shortened to lower orders when `lower_order_final` is set.

        Raises:
            ValueError: If `config.solver_order` is not 1, 2 or 3.
        """
        steps = num_inference_steps
        order = self.config.solver_order
        if order > 3:
            raise ValueError("Order > 3 is not supported by this scheduler")
        if order < 1:
            # Without this guard no branch below assigns `orders` and the return raises UnboundLocalError.
            raise ValueError(f"`solver_order` must be 1, 2 or 3, but got {order}")
        if self.config.lower_order_final:
            if order == 3:
                if steps % 3 == 0:
                    orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1]
                elif steps % 3 == 1:
                    orders = [1, 2, 3] * (steps // 3) + [1]
                else:
                    orders = [1, 2, 3] * (steps // 3) + [1, 2]
            elif order == 2:
                if steps % 2 == 0:
                    orders = [1, 2] * (steps // 2 - 1) + [1, 1]
                else:
                    orders = [1, 2] * (steps // 2) + [1]
            else:
                orders = [1] * steps
        else:
            if order == 3:
                orders = [1, 2, 3] * (steps // 3)
            elif order == 2:
                orders = [1, 2] * (steps // 2)
            else:
                orders = [1] * steps

        # With a final sigma of zero, lambda(t) is infinite at the last step. The first-order update handles that
        # limit, but the higher-order updates compute `1 / r0` with `r0 = h_0 / h -> 0` and produce inf/NaN.
        # (e.g. `solver_order=3` with `steps % 3 == 2` would otherwise end on a second-order step.)
        if self.config.final_sigmas_type == "zero" and orders:
            orders[-1] = 1
        return orders

    @property
    def step_index(self):
        """
M. Tolga Cangöz's avatar
M. Tolga Cangöz committed
255
        The index counter for current timestep. It will increase 1 after each scheduler step.
256
257
258
        """
        return self._step_index

259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Besides `self.timesteps`, this rebuilds `self.sigmas` (with the final sigma selected by `final_sigmas_type`
        appended), resets the solver state and recomputes `self.order_list`. If `lower_order_final` is off and either
        the number of steps is not a multiple of `solver_order` or `final_sigmas_type="zero"`, `lower_order_final` is
        switched on in the registered config.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: If `final_sigmas_type` is neither `sigma_min` nor `zero`.
        """
        self.num_inference_steps = num_inference_steps
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
            )
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        # `sigmas` deliberately stays on the CPU (to avoid too much CPU/GPU communication); only the timesteps are
        # moved to `device`. Creating it on the CPU directly avoids a pointless round-trip through `device`.
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.model_outputs = [None] * self.config.solver_order
        self.sample = None

        if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
            logger.warning(
                f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use a multiple of `solver_order` as `num_inference_steps` when using `lower_order_final=False`."
            )
            self.register_to_config(lower_order_final=True)

        if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
            logger.warning(
                f" `final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
            )
            self.register_to_config(lower_order_final=True)

        self.order_list = self.get_order_list(num_inference_steps)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
344
345
346
347
348
349
350
351
352
353
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
354
        batch_size, channels, *remaining_dims = sample.shape
355
356
357
358
359

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
360
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
361
362
363
364
365
366
367
368
369
370

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

371
        sample = sample.reshape(batch_size, channels, *remaining_dims)
372
373
374
        sample = sample.to(dtype)

        return sample
375

376
377
378
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
379
        log_sigma = np.log(np.maximum(sigma, 1e-10))
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

400
401
402
403
404
405
406
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t

        return alpha_t, sigma_t

407
408
409
410
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

Suraj Patil's avatar
Suraj Patil committed
411
412
413
414
415
416
417
418
419
420
421
422
423
424
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
425
426
427
428
429
430
431
432

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

433
    def convert_model_output(
        self,
        model_output: torch.FloatTensor,
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`, *optional*, deprecated):
                Accepted as the first positional argument or as a keyword for backward compatibility only. It has no
                effect; the noise level is read from `self.sigmas[self.step_index]`.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. May also be passed as the second
                positional argument.

        Returns:
            `torch.FloatTensor`:
                The converted model output: the predicted `x_0` for `dpmsolver++`, or the predicted noise for
                `dpmsolver`.

        Raises:
            ValueError: If `sample` is missing or `prediction_type` is not supported.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # Models with a learned variance ("learned" or "learned_range") output the mean followed by the variance along
        # the channel axis. DPM-Solver only needs the mean, which has as many channels as `sample` (rather than a
        # hard-coded 3, which only fits RGB pixel-space models).
        strip_variance = self.config.variance_type in ["learned", "learned_range"]

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type == "dpmsolver++":
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if strip_variance:
                    model_output = model_output[:, : sample.shape[1]]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverSinglestepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type == "dpmsolver":
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if strip_variance:
                    model_output = model_output[:, : sample.shape[1]]
                return model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverSinglestepScheduler."
                )

    def dpm_solver_first_order_update(
        self,
        model_output: torch.FloatTensor,
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Advance `sample` by one first-order DPMSolver step (equivalent to DDIM), going from the noise level
        `self.sigmas[self.step_index]` to `self.sigmas[self.step_index + 1]`.

        Args:
            model_output (`torch.FloatTensor`):
                The converted model output, as returned by `convert_model_output`: a data prediction for
                `dpmsolver++` or a noise prediction for `dpmsolver`.
            timestep (`int`, *optional*, deprecated):
                First positional argument or keyword; it has no effect beyond a deprecation warning.
            prev_timestep (`int`, *optional*, deprecated):
                Second positional argument or keyword; it has no effect beyond a deprecation warning.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. May also be given as the third
                positional argument.

        Returns:
            `torch.FloatTensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is not supplied.
        """
        n_positional = len(args)
        timestep = args[0] if n_positional > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if n_positional > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if n_positional <= 2:
                raise ValueError(" missing `sample` as a required keyward argument")
            sample = args[2]
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # Source is the current noise level, target the next one in the schedule.
        current = self.step_index
        sigma_next, sigma_here = self.sigmas[current + 1], self.sigmas[current]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_next)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_here)

        # Step size in half-log-SNR space, lambda = log(alpha) - log(sigma).
        h = (torch.log(alpha_t) - torch.log(sigma_t)) - (torch.log(alpha_s) - torch.log(sigma_s))

        algorithm = self.config.algorithm_type
        if algorithm == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif algorithm == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        return x_t

    def singlestep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        One step for the second-order singlestep DPMSolver.

        The update goes from `self.sigmas[self.step_index - 1]` (where `sample` lives) to
        `self.sigmas[self.step_index + 1]`. It uses the model output at `self.sigmas[self.step_index]` as the
        intermediate point.

        Args:
            model_output_list (`List[torch.FloatTensor]`):
                Converted model outputs. `model_output_list[-2]` is the output at the starting point
                (`self.step_index - 1`). `model_output_list[-1]` is the output at the intermediate point
                (`self.step_index`).
            sample (`torch.FloatTensor`):
                The sample at the starting point `self.sigmas[self.step_index - 1]`. For backward compatibility
                it may also be passed as the third positional argument.
            *args, **kwargs:
                Legacy `timestep_list` / `prev_timestep` (first/second positional argument, or keyword). They
                have no effect on the result; passing them only calls `deprecate`.

        Returns:
            `torch.FloatTensor`:
                The sample tensor at `self.sigmas[self.step_index + 1]`.

        Raises:
            ValueError: If `sample` is given neither as a keyword nor as the third positional argument.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # t: target point, s0: intermediate point, s1: starting point of this singlestep update
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        # lambda = log(alpha / sigma): the time variable the solver steps in
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        # h: full step in lambda; r0: fraction of the step at which the intermediate output m0 was taken
        h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
        r0 = h_0 / h
        # D1 is the difference (m0 - m1) rescaled from the partial step h_0 to the full step h
        D0, D1 = m1, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s1) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s1) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s1) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s1) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        return x_t

    def singlestep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        One step for the third-order singlestep DPMSolver.

        The update goes from `self.sigmas[self.step_index - 2]` (where `sample` lives) to
        `self.sigmas[self.step_index + 1]`. It uses the model outputs at `self.step_index - 1` and
        `self.step_index` as intermediate points.

        Args:
            model_output_list (`List[torch.FloatTensor]`):
                Converted model outputs:
                - `model_output_list[-3]` is the output at the starting point (`self.step_index - 2`).
                - `model_output_list[-2]` is the output at `self.step_index - 1`.
                - `model_output_list[-1]` is the output at `self.step_index`.
            sample (`torch.FloatTensor`):
                The sample at the starting point `self.sigmas[self.step_index - 2]`. For backward compatibility
                it may also be passed as the third positional argument.
            *args, **kwargs:
                Legacy `timestep_list` / `prev_timestep` (first/second positional argument, or keyword). They
                have no effect on the result; passing them only calls `deprecate`.

        Returns:
            `torch.FloatTensor`:
                The sample tensor at `self.sigmas[self.step_index + 1]`.

        Raises:
            ValueError: If `sample` is given neither as a keyword nor as the third positional argument.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # t: target point, s0/s1: intermediate points, s2: starting point of this singlestep update
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        # lambda = log(alpha / sigma): the time variable the solver steps in
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        # h: full step in lambda; r0 / r1: fractions of the step at which m0 (s0) / m1 (s1) were taken
        h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m2
        # Differences of m1 and m0 against m2, each rescaled to the full step h
        D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2)
        # Combined first- and second-order difference estimates
        D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
        D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                # The midpoint variant only uses the difference taken at s0 (D1_1)
                x_t = (
                    (sigma_t / sigma_s2) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s2) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                    - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                # The midpoint variant only uses the difference taken at s0 (D1_1)
                x_t = (
                    (alpha_t / alpha_s2) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s2) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
                )
        return x_t

    def singlestep_dpm_solver_update(
        self,
        model_output_list: List[torch.FloatTensor],
        *args,
        sample: torch.FloatTensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        One step for the singlestep DPMSolver. Dispatches to the first-, second- or third-order update.

        Args:
            model_output_list (`List[torch.FloatTensor]`):
                Converted model outputs. The first-order update receives only `model_output_list[-1]`. The
                higher-order updates receive the whole list.
            sample (`torch.FloatTensor`):
                The sample the update starts from. May also be passed as the third positional argument.
            order (`int`):
                The solver order at this step (1, 2 or 3). May also be passed as the fourth positional argument.
            *args, **kwargs:
                Legacy `timestep_list` / `prev_timestep` (first/second positional argument, or keyword). They
                have no effect on the result; passing them only calls `deprecate`.

        Returns:
            `torch.FloatTensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` or `order` is missing, or if `order` is not 1, 2 or 3.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if order == 1:
            return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
        elif order == 2:
            return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
        elif order == 3:
            return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
        else:
            raise ValueError(f"Order must be 1, 2, 3, got {order}")

838
839
840
841
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
842

843
        index_candidates = (schedule_timesteps == timestep).nonzero()
844
845
846
847
848
849
850
851
852
853
854
855

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

856
857
858
859
860
861
862
863
864
865
866
867
868
869
        return step_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        If `self.begin_index` is unset, the counter is looked up from `timestep` with `self.index_for_timestep`.
        Tensor timesteps are first moved to `self.timesteps.device`. Otherwise the counter starts at
        `self._begin_index`.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

871
872
873
874
875
876
877
878
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
879
880
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the singlestep DPMSolver.
881
882

        Args:
883
884
885
886
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
887
            sample (`torch.FloatTensor`):
888
889
890
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
891
892

        Returns:
893
894
895
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
896
897
898
899
900
901
902

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

903
904
        if self.step_index is None:
            self._init_step_index(timestep)
905

906
        model_output = self.convert_model_output(model_output, sample=sample)
907
908
909
910
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

911
        order = self.order_list[self.step_index]
Patrick von Platen's avatar
Patrick von Platen committed
912
913
914
915
916
917

        #  For img2img denoising might start with order>1 which is not possible
        #  In this case make sure that the first two steps are both order=1
        while self.model_outputs[-order] is None:
            order -= 1

918
919
920
921
        # For single-step solvers, we use the initial value at each time with order = 1.
        if order == 1:
            self.sample = sample

922
923
924
925
        prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)

        # upon completion increase step index by one
        self._step_index += 1
926
927
928
929
930
931
932
933
934
935
936
937

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        No-op input scaling. It exists so this scheduler can be used in place of schedulers that do scale the
        denoising model input depending on the current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample. Any additional positional or keyword arguments are accepted and ignored.

        Returns:
            `torch.FloatTensor`:
                The very same `sample` object, unchanged.
        """
        unscaled = sample
        return unscaled

947
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
948
949
950
951
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
952
        timesteps: torch.IntTensor,
953
    ) -> torch.FloatTensor:
954
955
956
957
958
959
960
961
962
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)
963

964
965
966
967
968
        # begin_index is None when the scheduler is used for training
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]
969

970
971
972
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)
973

974
975
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
976
977
978
979
        return noisy_samples

    def __len__(self):
        """Return the configured number of training timesteps (`self.config.num_train_timesteps`)."""
        train_steps = self.config.num_train_timesteps
        return train_steps