scheduling_dpmsolver_multistep.py 54 KB
Newer Older
1
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
24
from ..utils import deprecate, is_scipy_available
Dhruv Nair's avatar
Dhruv Nair committed
25
from ..utils.torch_utils import randn_tensor
Kashif Rasul's avatar
Kashif Rasul committed
26
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
28


29
30
31
32
if is_scipy_available():
    import scipy.stats


33
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a float32 tensor of length `num_diffusion_timesteps`, where entry `i` is
        `min(1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), max_beta)` with `N = num_diffusion_timesteps`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # Each beta is one minus the ratio of consecutive alpha_bar values, clipped at max_beta so the final steps
    # (where alpha_bar approaches 0 for the cosine schedule) do not produce beta == 1.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is exactly zero.

    Implements Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: the square roots of the cumulative alphas are
    shifted so the last one becomes 0, then stretched so the first one keeps its original value, and the result is
    converted back into per-step betas.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    sqrt_cumprod = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_cumprod[0].clone()
    last = sqrt_cumprod[-1].clone()

    # Shift so the final entry is zero, then stretch so the first entry is restored.
    sqrt_cumprod = (sqrt_cumprod - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    cumprod = sqrt_cumprod**2
    per_step_alphas = torch.cat([cumprod[0:1], cumprod[1:] / cumprod[:-1]])
    return 1 - per_step_alphas


115
116
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of the training noise schedule.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
            paper, and the `dpmsolver++` type implements the algorithms in the
            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. `dpmsolver` and
            `sde-dpmsolver` are deprecated, and `deis` is accepted and mapped to `dpmsolver++`.
        solver_type (`str`, defaults to `midpoint`):
            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
            `logrho`, `bh1` and `bh2` are accepted and mapped to `midpoint`.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        euler_at_final (`bool`, defaults to `False`):
            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
            richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
            steps, but sometimes may result in blurring.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}. At most one of
            `use_karras_sigmas`, `use_exponential_sigmas` and `use_beta_sigmas` may be enabled.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. Requires scipy.
        use_lu_lambdas (`bool`, *optional*, defaults to `False`):
            Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
            the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
            `lambda(t)`.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            The shift value for the timestep schedule for flow matching.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0;
            `zero` is only supported with `algorithm_type` `dpmsolver++` or `sde-dpmsolver++`.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled; one of `linspace`, `leading` or `trailing`. Refer to Table 2 of the
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)
            for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied when
            `timestep_spacing="leading"`.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

Kashif Rasul's avatar
Kashif Rasul committed
201
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
202
    order = 1
203
204
205
206
207
208
209
210

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_lu_lambdas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        """
        Build the training noise schedule and initialise the multistep solver state.

        The body reads `self.config` before assigning anything, so `@register_to_config` is expected to have recorded
        the arguments there already.

        Raises:
            ImportError: if `use_beta_sigmas=True` but scipy is not installed.
            ValueError: if more than one of `use_beta_sigmas`, `use_exponential_sigmas`, `use_karras_sigmas` is
                enabled, or if `final_sigmas_type="zero"` is combined with `dpmsolver` / `sde-dpmsolver`.
            NotImplementedError: for an unknown `beta_schedule`, `algorithm_type` or `solver_type`.
        """
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

        # Training beta schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
                # Keep the local in sync with the remapped config so the `final_sigmas_type` check below validates the
                # algorithm that will actually run; otherwise a "deis" config with the default "zero" is rejected.
                algorithm_type = "dpmsolver++"
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # History of converted model outputs, one slot per solver order.
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        Position of the current timestep; the original contract is that it grows by one after each scheduler step.
        `None` until stepping begins.
        """
        current = self._step_index
        return current

    @property
    def begin_index(self):
        """
        Index of the first timestep, or `None` when no pipeline has set it via `set_begin_index`.
        """
        first = self._begin_index
        return first

    def set_begin_index(self, begin_index: int = 0):
        """
        Record the index of the first timestep; pipelines should call this before inference starts.

        Note that `set_timesteps` resets this value to `None`.

        Args:
            begin_index (`int`, defaults to 0):
                The value stored as the scheduler's begin index.
        """
        setattr(self, "_begin_index", begin_index)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. Must be `None`
                when `timesteps` is passed.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
                `self.sigmas` is always kept on the CPU.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be
                `None`, the `timestep_spacing` attribute is ignored, and none of `use_karras_sigmas`,
                `use_lu_lambdas`, `use_exponential_sigmas`, `use_beta_sigmas` or `use_flow_sigmas` may be enabled.

        Raises:
            ValueError: if neither or both of `num_inference_steps` and `timesteps` are given, if custom `timesteps`
                are combined with a sigma option that regenerates the timesteps, or if `timestep_spacing` /
                `final_sigmas_type` hold an unsupported value.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
        if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
        if timesteps is not None and self.config.use_beta_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
        if timesteps is not None and self.config.use_flow_sigmas:
            # The flow-sigma branch below rebuilds the schedule from `num_inference_steps`, which is `None` when custom
            # timesteps are given; without this guard it fails with an opaque TypeError (`None + 1`).
            raise ValueError("Cannot set `timesteps` with `config.use_flow_sigmas = True`.")

        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Clipping the minimum of all lambda(t) for numerical stability.
            # This is critical for cosine (squaredcos_cap_v2) noise schedule.
            clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
            last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        # Training-schedule sigmas, sqrt((1 - alpha_bar) / alpha_bar), indexed by training timestep.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.beta_schedule != "squaredcos_cap_v2":
                timesteps = timesteps.round()
        elif self.config.use_lu_lambdas:
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.beta_schedule != "squaredcos_cap_v2":
                timesteps = timesteps.round()
        elif self.config.use_exponential_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
        else:
            # Look up the training sigma for each (possibly custom) timestep.
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        # One more sigma than timesteps: the trailing entry is the noise level reached after the final step.
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample; the flattened size is computed
    # with `math.prod` so samples without trailing dimensions reshape with an integer size.
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487

        Args:
            sample (`torch.Tensor`):
                The predicted x_0, with at least two dimensions. The first dimension is treated as the batch axis, and
                the percentile is taken over all remaining elements of each batch item.

        Returns:
            `torch.Tensor`: the thresholded sample, with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image.
        # `math.prod([])` is the int 1, whereas `np.prod([])` is the float 1.0, which `reshape` rejects for 2-D input.
        sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Map sigma value(s) to fractional timestep indices by linear interpolation in log-sigma space.

        Args:
            sigma: the sigma(s) to convert; must expose `.shape` (for example a NumPy scalar or array).
            log_sigmas (`np.ndarray`): log of the training sigmas, one entry per training timestep. Assumed to be
                ascending (true for the schedule built in `set_timesteps`) — TODO confirm for other callers.

        Returns:
            `np.ndarray`: the interpolated timestep(s), reshaped to `sigma.shape`.
        """
        target = np.log(np.maximum(sigma, 1e-10))
        n_train = log_sigmas.shape[0]

        # Signed gap from every training log-sigma (rows) to the target (columns).
        gaps = target - log_sigmas[:, np.newaxis]

        # The running count of non-negative gaps first reaches its maximum at the last training entry that is <= the
        # target; the clamp guarantees a right-hand neighbour exists.
        running = np.cumsum(gaps >= 0, axis=0)
        below = running.argmax(axis=0).clip(max=n_train - 2)
        above = below + 1

        lo_val = log_sigmas[below]
        hi_val = log_sigmas[above]

        # Fraction of the way from `below` to `above`, clamped to the bracket.
        weight = np.clip((lo_val - target) / (lo_val - hi_val), 0, 1)

        timestep = (1 - weight) * below + weight * above
        return timestep.reshape(sigma.shape)

    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Split a sigma into its `(alpha_t, sigma_t)` pair.

        With `use_flow_sigmas` the pair is `(1 - sigma, sigma)`. Otherwise sigma is the ratio `sigma_t / alpha_t` of
        the VP schedule, so `alpha_t = 1 / sqrt(sigma**2 + 1)` and `sigma_t = sigma * alpha_t`.
        """
        if not self.config.use_flow_sigmas:
            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
            return alpha_t, sigma * alpha_t
        return 1 - sigma, sigma

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """
        Constructs the noise schedule of Karras et al. (2022).

        The sigmas are spaced uniformly in `sigma ** (1 / rho)` between `sigma_max` and `sigma_min`. Both endpoints
        come from the config when it defines them, and otherwise from the first and last entries of `in_sigmas`. The
        result is a NumPy array of length `num_inference_steps`.
        """
        # Only some schedulers' configs carry sigma_min / sigma_max; missing or None falls back to `in_sigmas`.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        inv_rho = 1 / rho
        start = sigma_max**inv_rho
        end = sigma_min**inv_rho
        ramp = np.linspace(0, 1, num_inference_steps)
        return (start + ramp * (end - start)) ** rho

    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """
        Constructs the noise schedule of Lu et al. (2022).

        With `rho = 1` this spaces the values uniformly between the first and last entries of `in_lambdas`. The result
        is a NumPy array of length `num_inference_steps`.
        """
        rho = 1.0  # 1.0 is the value used in the paper
        exponent = 1 / rho

        head: float = in_lambdas[0].item()
        tail: float = in_lambdas[-1].item()
        start = head**exponent
        end = tail**exponent

        ramp = np.linspace(0, 1, num_inference_steps)
        return (start + ramp * (end - start)) ** rho

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """
        Constructs an exponential noise schedule.

        The sigmas are spaced uniformly in log space from `sigma_max` down to `sigma_min`. Each endpoint comes from the
        config when it is defined there, otherwise from `in_sigmas[0]` / `in_sigmas[-1]`. Returns a NumPy array.
        """
        # Only some schedulers' configs carry sigma_min / sigma_max; missing or None falls back to `in_sigmas`.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        log_grid = np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)
        return np.exp(log_grid)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """
        From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024).

        Each sigma is `sigma_min + ppf * (sigma_max - sigma_min)`, where `ppf` is the Beta(`alpha`, `beta`) inverse CDF
        evaluated on a grid running from 1 down to 0. Endpoints come from the config when defined there, otherwise
        from `in_sigmas[-1]` / `in_sigmas[0]`. Returns a NumPy array of length `num_inference_steps`.
        """
        # Only some schedulers' configs carry sigma_min / sigma_max; missing or None falls back to `in_sigmas`.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        values = []
        for quantile in 1 - np.linspace(0, 1, num_inference_steps):
            ppf = scipy.stats.beta.ppf(quantile, alpha, beta)
            values.append(sigma_min + (ppf * (sigma_max - sigma_min)))
        return np.array(values)

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (deprecated):
                May still be passed as the first positional argument or as a keyword; it is ignored and only
                triggers a deprecation warning. The noise level is taken from `self.sigmas[self.step_index]`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. For backward compatibility it may
                also be passed as the second positional argument.

        Returns:
            `torch.Tensor`:
                The converted model output: a data (`x0`) prediction for `dpmsolver++` / `sde-dpmsolver++`, and a noise
                (`epsilon`) prediction for `dpmsolver` / `sde-dpmsolver`.

        Raises:
            ValueError: If `sample` is missing, or if `prediction_type` is not supported by the configured
                `algorithm_type`.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                # NOTE(review): the slice keeps the first 3 channels, which assumes a 3-channel model whose
                # learned variance is stacked after the mean — confirm for models with other channel counts.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                # x_t = alpha_t * x0 + sigma_t * eps  =>  x0 = (x_t - sigma_t * eps) / alpha_t
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            elif self.config.prediction_type == "flow_prediction":
                # Flow matching uses the raw sigma as the interpolation weight, not the alpha/sigma pair.
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                # Thresholding operates on x0, so convert eps -> x0, threshold, then convert back to eps.
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        The step goes from noise level `self.sigmas[self.step_index]` to `self.sigmas[self.step_index + 1]`.

        Args:
            model_output (`torch.Tensor`):
                The converted model output (see `convert_model_output`): a data prediction for the `++` variants and
                a noise prediction otherwise.
            timestep, prev_timestep (deprecated):
                May still be passed positionally or by keyword; they are ignored and only trigger deprecation
                warnings.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. For backward compatibility it may
                also be passed as the third positional argument.
            noise (`torch.Tensor`, *optional*):
                Noise for the stochastic `sde-dpmsolver` / `sde-dpmsolver++` updates; required for those variants.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is not provided.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # `s` is the current noise level and `t` the one being stepped to.
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        # Log-SNR lambda = log(alpha) - log(sigma); h is the step size measured in lambda.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        The step goes from `self.sigmas[self.step_index]` to `self.sigmas[self.step_index + 1]`, using the previous
        noise level `self.sigmas[self.step_index - 1]` to estimate the first derivative of the model output.

        Args:
            model_output_list (`List[torch.Tensor]`):
                Converted model outputs of the most recent steps; `[-1]` is the current one and `[-2]` the previous.
            timestep_list, prev_timestep (deprecated):
                May still be passed positionally or by keyword; they are ignored and only trigger deprecation
                warnings.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. For backward compatibility it may
                also be passed as the third positional argument.
            noise (`torch.Tensor`, *optional*):
                Noise for the stochastic `sde-dpmsolver` / `sde-dpmsolver++` updates; required for those variants.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is not provided.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # t: target level, s0: current level, s1: previous level.
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        # Log-SNR lambda = log(alpha) - log(sigma).
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        # D1 is the backward difference of the model output, rescaled to the current step size h.
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        The step goes from `self.sigmas[self.step_index]` to `self.sigmas[self.step_index + 1]`, using the two
        previous noise levels `self.sigmas[self.step_index - 1]` and `self.sigmas[self.step_index - 2]` to estimate
        first and second derivatives of the model output.

        Args:
            model_output_list (`List[torch.Tensor]`):
                Converted model outputs of the most recent steps; `[-1]` is the current one, `[-2]` and `[-3]` the two
                before it.
            timestep_list, prev_timestep (deprecated):
                May still be passed positionally or by keyword; they are ignored and only trigger deprecation
                warnings.
            sample (`torch.Tensor`):
                A current instance of a sample created by diffusion process. For backward compatibility it may also be
                passed as the third positional argument.
            noise (`torch.Tensor`, *optional*):
                Noise for the stochastic `sde-dpmsolver++` update; required for that variant.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is not provided.
            NotImplementedError: If `algorithm_type` has no third-order update (e.g. `sde-dpmsolver`).
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # t: target level, s0: current level, s1/s2: the two previous levels.
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        # Log-SNR lambda = log(alpha) - log(sigma).
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        # Backward differences rescaled to the current step size, combined into first- (D1) and
        # second-order (D2) correction terms.
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 (DPM-Solver++) for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            # The D0/D1/D2 coefficients are the `dpmsolver++` ones with h replaced by 2h. Negating the `dpmsolver++`
            # D2 coefficient turns its `- 0.5` into `+ 0.5`, so the D2 coefficient vanishes as h -> 0 like the D1 one.
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 + 0.5)) * D2
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        else:
            # Previously this fell through and failed with an UnboundLocalError on `x_t`.
            raise NotImplementedError(
                f"The third-order multistep update is not implemented for algorithm_type={self.config.algorithm_type!r}."
            )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Find the position of `timestep` in a timestep schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep to look up; compared element-wise against the schedule.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The schedule to search. Defaults to `self.timesteps`.

        Returns:
            `int`:
                The index of `timestep` in the schedule. If it occurs more than once, the second occurrence is used.
                If it does not occur, the last index of the searched schedule is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            # Fall back to the last index of the schedule that was actually searched, so an explicitly passed
            # schedule never yields an index derived from a different tensor.
            step_index = len(schedule_timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
1034

1035
1036
    def step(
        self,
1037
        model_output: torch.Tensor,
1038
        timestep: Union[int, torch.Tensor],
1039
        sample: torch.Tensor,
1040
        generator=None,
1041
        variance_noise: Optional[torch.Tensor] = None,
1042
1043
1044
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
1045
1046
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.
1047
1048

        Args:
1049
            model_output (`torch.Tensor`):
1050
1051
1052
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
1053
            sample (`torch.Tensor`):
1054
1055
1056
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
1057
            variance_noise (`torch.Tensor`):
1058
1059
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
1060
1061
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
1062
1063

        Returns:
1064
1065
1066
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
1067
1068
1069
1070
1071
1072
1073

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

1074
1075
1076
        if self.step_index is None:
            self._init_step_index(timestep)

1077
1078
        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
1079
1080
1081
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
1082
1083
        )
        lower_order_second = (
1084
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
1085
1086
        )

1087
        model_output = self.convert_model_output(model_output, sample=sample)
1088
1089
1090
1091
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

1092
1093
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
1094
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
1095
            noise = randn_tensor(
1096
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
1097
            )
1098
1099
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
1100
1101
1102
        else:
            noise = None

1103
        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
1104
            prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
1105
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
1106
            prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
1107
        else:
StAlKeR7779's avatar
StAlKeR7779 committed
1108
            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
1109
1110
1111
1112

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

1113
1114
1115
        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

1116
1117
1118
        # upon completion increase step index by one
        self._step_index += 1

1119
1120
1121
1122
1123
        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

1124
    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        This scheduler does not scale its input: `sample` is returned unchanged, and any extra arguments (such as a
        timestep) are accepted and ignored.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample (here, the very same `sample` object).
        """
        return sample

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Noise `original_samples` to the noise levels of `timesteps`: `alpha_t * original_samples + sigma_t * noise`.

        Args:
            original_samples (`torch.Tensor`):
                The clean samples; the first dimension is matched against `timesteps`.
            noise (`torch.Tensor`):
                The noise to add.
            timesteps (`torch.IntTensor`):
                One timestep per sample (1-D), used to pick each sample's sigma.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Append trailing singleton dims so the per-sample sigma broadcasts over the remaining sample dims.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps