# Copyright 2024 FLAIR Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


if is_scipy_available():
    import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a float32 tensor of `num_diffusion_timesteps` betas used by the scheduler to step the
        model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i) on a uniform grid over [0, 1], capped at `max_beta`
    # (the cosine alpha_bar reaches ~0 at t = 1, which would otherwise give beta ~= 1).
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs).

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample), `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction` (flow-matching
            velocity).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        algorithm_type (`str`, defaults to `deis`):
            The algorithm type for the solver. Only `deis` is implemented; `dpmsolver` and `dpmsolver++` are accepted
            and replaced by `deis`.
        solver_type (`str`, defaults to `logrho`):
            The solver type for the DEIS update. Only `logrho` is implemented; `midpoint`, `heun`, `bh1` and `bh2` are
            accepted and replaced by `logrho`.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Requires
            scipy. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more
            information. At most one of `use_karras_sigmas`, `use_exponential_sigmas` and `use_beta_sigmas` may be
            enabled.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow-matching sigmas (`alpha_t = 1 - sigma`, `sigma_t = sigma`) for the noise schedule
            during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            The shift applied to the flow-matching sigmas. Valid only when `use_flow_sigmas=True`.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied when
            `timestep_spacing="leading"`.
    """

    # Names of every scheduler listed in `KarrasDiffusionSchedulers`; presumably consulted by `SchedulerMixin` /
    # `ConfigMixin` to decide which scheduler configs are interchangeable with this one — confirm there.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # Class-level scheduler order read outside this file (NOTE(review): exact semantics live in SchedulerMixin /
    # the pipelines — confirm before changing).
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "deis",
        solver_type: str = "logrho",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        """
        Build the training noise schedule and reset the solver state. See the class docstring for the arguments.

        Raises:
            ImportError: If `use_beta_sigmas=True` but scipy is not installed.
            ValueError: If more than one of the Karras / exponential / beta sigma options is enabled.
            NotImplementedError: For an unknown `beta_schedule`, `algorithm_type` or `solver_type`.
        """
        # The sigma-schedule flags are read back from `self.config`, which the `@register_to_config` decorator is
        # expected to have populated from the arguments before this body runs.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )

        # Training beta schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        # Per-training-timestep ratio sigma_t / alpha_t; `_sigma_to_alpha_sigma_t` inverts this.
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DEIS: legacy DPM-Solver names are remapped onto the only implemented variant.
        if algorithm_type not in ["deis"]:
            if algorithm_type in ["dpmsolver", "dpmsolver++"]:
                self.register_to_config(algorithm_type="deis")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["logrho"]:
            if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
                self.register_to_config(solver_type="logrho")
            else:
                raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # History of converted model outputs used by the multistep updates; one slot per solver order.
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.

        `None` until stepping starts; both `__init__` and `set_timesteps` reset it.
        """
        return self._step_index
    @property
    def begin_index(self):
        """
        Index of the first timestep to run, expected to be provided by the pipeline through `set_begin_index`.

        `__init__` and `set_timesteps` both reset it to `None`.
        """
        first_index = self._begin_index
        return first_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to 0):
                The begin index for the scheduler. It is stored as-is and exposed through the `begin_index` property.
        """
        self._begin_index = begin_index
    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        This also rebuilds `self.sigmas` for the chosen schedule (one trailing extra entry so that the update at the
        last step can read `sigmas[step_index + 1]`) and resets the multistep solver state.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
                `self.sigmas` always stays on the CPU.

        Raises:
            ValueError: If `config.timestep_spacing` is not `linspace`, `leading` or `trailing`.
        """
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Training sigmas (sigma_t / alpha_t), increasing with the training timestep.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_exponential_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            # NOTE(review): unlike the Karras branch these fractional timesteps are not rounded, so the int64 cast
            # below truncates them.
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_beta_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            # NOTE(review): not rounded either; truncated by the int64 cast below.
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            # Apply the flow shift, then order from high to low noise and drop the final (lowest) value.
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            # Final sigma is the one at training timestep 0.
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
    # Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample (fixed for 2-D inputs)
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487

        Args:
            sample (`torch.Tensor`):
                Predicted clean sample of shape `(batch_size, channels, *remaining_dims)`; `remaining_dims` may be
                empty.

        Returns:
            `torch.Tensor`: The thresholded sample, with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image.
        # `math.prod` returns the int 1 for empty `remaining_dims` (2-D input); `np.prod([])` is the float 1.0,
        # which `reshape` rejects.
        sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
    # Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    # (also accepts plain Python floats)
    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Map noise level(s) to fractional training-timestep indices by linear interpolation in log-sigma space.

        Args:
            sigma (`float` or `np.ndarray`):
                The noise level(s) to convert.
            log_sigmas (`np.ndarray`):
                Log of the training sigmas, one entry per training timestep, expected to increase with the timestep
                (as built in `set_timesteps`).

        Returns:
            `np.ndarray`: The interpolated timestep(s), with the same shape as `sigma`.
        """
        # `np.asarray` gives plain Python floats the `.shape` needed by the final reshape.
        sigma = np.asarray(sigma)
        # get log sigma (clamped so that sigma == 0 does not produce -inf)
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range: last training index whose log-sigma is <= log_sigma, clipped so high_idx stays valid
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert a noise level `sigma` into the `(alpha_t, sigma_t)` pair used by the solver updates.

        With `use_flow_sigmas=True` this is `(1 - sigma, sigma)`. Otherwise `sigma` is the ratio `sigma_t / alpha_t`
        of a variance-preserving schedule (as stored in `self.sigmas`), so `alpha_t = 1 / sqrt(sigma**2 + 1)` and
        `sigma_t = sigma * alpha_t`.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
        else:
            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
            sigma_t = sigma * alpha_t

        return alpha_t, sigma_t

    # Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: np.ndarray, num_inference_steps: int) -> np.ndarray:
        """
        Constructs the noise schedule of Karras et al. (2022).

        Args:
            in_sigmas (`np.ndarray`):
                Training sigmas ordered from largest to smallest (as `set_timesteps` passes them). Only the last and
                first entries are used, as `sigma_min` / `sigma_max`, when the config does not define those.
            num_inference_steps (`int`):
                The number of sigmas to generate.

        Returns:
            `np.ndarray`: `num_inference_steps` sigmas from `sigma_max` down to `sigma_min`, spaced uniformly in
            `sigma ** (1 / rho)`.
        """
        # Schedulers sharing this helper may carry `sigma_min` / `sigma_max` in their config (this scheduler's
        # `__init__` takes neither); otherwise fall back to the ends of `in_sigmas`.
        sigma_min = getattr(self.config, "sigma_min", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
    # Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: np.ndarray, num_inference_steps: int) -> np.ndarray:
        """
        Constructs an exponential noise schedule.

        Args:
            in_sigmas (`np.ndarray`):
                Training sigmas ordered from largest to smallest (as `set_timesteps` passes them). Only the last and
                first entries are used, as `sigma_min` / `sigma_max`, when the config does not define those.
            num_inference_steps (`int`):
                The number of sigmas to generate.

        Returns:
            `np.ndarray`: `num_inference_steps` sigmas from `sigma_max` down to `sigma_min`, evenly spaced in
            log-sigma.
        """
        # Schedulers sharing this helper may carry `sigma_min` / `sigma_max` in their config (this scheduler's
        # `__init__` takes neither); otherwise fall back to the ends of `in_sigmas`.
        sigma_min = getattr(self.config, "sigma_min", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas
    # Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: np.ndarray, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> np.ndarray:
        """
        From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)

        Args:
            in_sigmas (`np.ndarray`):
                Training sigmas ordered from largest to smallest (as `set_timesteps` passes them). Only the last and
                first entries are used, as `sigma_min` / `sigma_max`, when the config does not define those.
            num_inference_steps (`int`):
                The number of sigmas to generate.
            alpha (`float`, defaults to 0.6), beta (`float`, defaults to 0.6):
                Shape parameters of the Beta distribution whose quantiles place the sigmas.

        Returns:
            `np.ndarray`: `num_inference_steps` sigmas from `sigma_max` down to `sigma_min`.

        Requires scipy: `scipy.stats` is only imported at module level when scipy is available.
        """
        # Schedulers sharing this helper may carry `sigma_min` / `sigma_max` in their config (this scheduler's
        # `__init__` takes neither); otherwise fall back to the ends of `in_sigmas`.
        sigma_min = getattr(self.config, "sigma_min", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        # Beta quantiles at probabilities 1 -> 0 (i.e. 1 -> 0) map the schedule from sigma_max down to sigma_min.
        # `ppf` is vectorized, so all quantiles come from a single call.
        quantiles = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_inference_steps), alpha, beta)
        sigmas = sigma_min + quantiles * (sigma_max - sigma_min)
        return sigmas
    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DEIS algorithm needs.

        The raw output is first turned into a prediction of the clean sample `x0` according to
        `config.prediction_type` (optionally dynamically thresholded). It is then re-expressed as the noise
        prediction `(sample - alpha_t * x0_pred) / sigma_t` that the DEIS update functions consume. The current noise
        level is `self.sigmas[self.step_index]`.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`, *optional*, deprecated):
                Accepted as the first positional argument or as a keyword for backward compatibility; it has no
                effect.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. May also be passed positionally
                after `timestep`.

        Returns:
            `torch.Tensor`:
                The converted model output (a noise prediction).

        Raises:
            ValueError: If `sample` is missing or `config.prediction_type` is not supported.
            NotImplementedError: If `config.algorithm_type` is not `"deis"`.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        # Step 1: recover the x0 prediction from whatever the model was trained to predict.
        if self.config.prediction_type == "epsilon":
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            # Flow matching uses the raw sigma here (equal to `sigma_t` when `use_flow_sigmas=True`).
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        # Step 2: DEIS works on the noise prediction, so convert x0 back to epsilon.
        if self.config.algorithm_type == "deis":
            return (sample - alpha_t * x0_pred) / sigma_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")

    def deis_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DEIS (equivalent to DDIM).

        The step goes from noise level `s = self.sigmas[self.step_index]` to `t = self.sigmas[self.step_index + 1]`:
        `x_t = (alpha_t / alpha_s) * x_s - sigma_t * (exp(h) - 1) * model_output`, where `h = lambda_t - lambda_s`
        and `lambda = log(alpha) - log(sigma)`.

        Args:
            model_output (`torch.Tensor`):
                The model output, treated here as a noise prediction (e.g. as returned by `convert_model_output`).
            timestep (`int`, *optional*, deprecated):
                Accepted as the first positional argument or as a keyword; it has no effect.
            prev_timestep (`int`, *optional*, deprecated):
                Accepted as the second positional argument or as a keyword; it has no effect.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. May also be passed as the third
                positional argument.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is missing.
            NotImplementedError: If `config.algorithm_type` is not `"deis"`.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # s = current noise level, t = next (lower) noise level.
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "deis":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
        return x_t

    def multistep_deis_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DEIS.

        With `rho = sigma / alpha`, the update adds to `sample / alpha_s0` the integral, over rho from `rho_s0` to
        `rho_t`, of a Lagrange interpolant in `log(rho)` through the last two model outputs. The result is then
        rescaled by `alpha_t`.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps. The last entry
                belongs to the current step and the second-to-last to the previous one.
            timestep_list (`List[int]`, *optional*, deprecated):
                Accepted as the first positional argument or as a keyword; it has no effect.
            prev_timestep (`int`, *optional*, deprecated):
                Accepted as the second positional argument or as a keyword; it has no effect.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. May also be passed as the third
                positional argument.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is missing.
            NotImplementedError: If `config.algorithm_type` is not `"deis"`.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # t = next noise level, s0 = current, s1 = previous step's.
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1

        if self.config.algorithm_type == "deis":

            def ind_fn(t, b, c):
                # Antiderivative in t of the log-rho Lagrange basis that is 1 at b and 0 at c:
                # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
                return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))

            # Definite integrals of each basis function from rho_s0 to rho_t.
            coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
            coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)

            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
            return x_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")

    def multistep_deis_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Advance the sample by one step of the third-order multistep DEIS solver.

        The three most recent model outputs are weighted by the integrals, from the current rho to the next rho
        (rho = sigma / alpha), of the quadratic Lagrange basis polynomials built in log-rho space.

        Args:
            model_output_list (`List[torch.Tensor]`):
                History of model outputs; only the last three entries are used (`[-1]` is the current one).
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: if `sample` is given neither as a keyword nor as the third extra positional argument.
            NotImplementedError: if `config.algorithm_type` is not `"deis"`.
        """
        # Legacy call signature: (model_output_list, timestep_list, prev_timestep, sample).
        legacy_timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        legacy_prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) <= 2:
                raise ValueError("missing `sample` as a required keyword argument")
            sample = args[2]

        if legacy_timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        if legacy_prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        idx = self.step_index
        # t is the target (next) point, s0 the current one, s1/s2 the two previous ones.
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[idx + 1])
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(self.sigmas[idx])
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(self.sigmas[idx - 1])
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(self.sigmas[idx - 2])

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        rho_t = sigma_t / alpha_t
        rho_s0 = sigma_s0 / alpha_s0
        rho_s1 = sigma_s1 / alpha_s1
        rho_s2 = sigma_s2 / alpha_s2

        if self.config.algorithm_type != "deis":
            raise NotImplementedError("only support log-rho multistep deis now")

        def _log_quadratic_integral(t, b, c, d):
            # Antiderivative in t of (log(t) - log(c)) (log(t) - log(d)) / ((log(b) - log(c)) (log(b) - log(d))),
            # i.e. the log-space Lagrange basis that equals 1 at node b and 0 at nodes c and d.
            log_t, log_b, log_c, log_d = np.log(t), np.log(b), np.log(c), np.log(d)
            numerator = t * (log_c * (log_d - log_t + 1) - log_d * log_t + log_d + log_t**2 - 2 * log_t + 2)
            denominator = (log_b - log_c) * (log_b - log_d)
            return numerator / denominator

        def _coefficient(b, c, d):
            # Definite integral of the basis for node b over [rho_s0, rho_t].
            return _log_quadratic_integral(rho_t, b, c, d) - _log_quadratic_integral(rho_s0, b, c, d)

        coef1 = _coefficient(rho_s0, rho_s1, rho_s2)
        coef2 = _coefficient(rho_s1, rho_s2, rho_s0)
        coef3 = _coefficient(rho_s2, rho_s0, rho_s1)

        return alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2)

737
738
739
740
    # Adapted from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    # (diverges: the no-match fallback uses the length of the schedule actually searched).
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Find the position of `timestep` in a timestep schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep to look up; compared element-wise against `schedule_timesteps`.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The schedule to search. Defaults to `self.timesteps`.

        Returns:
            `int`:
                The index of `timestep` in the schedule. If it occurs more than once the second occurrence is
                returned; if it does not occur at all the last index of the searched schedule is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            # Fall back to the final position of the schedule being searched (not necessarily `self.timesteps`).
            step_index = len(schedule_timesteps) - 1
        # When the timestep appears more than once, take the second occurrence (the only one otherwise).
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    # Adapted from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Set `self._step_index` before the first denoising step.

        If `begin_index` has been set, counting starts there. Otherwise the index is looked up from `timestep` with
        `index_for_timestep`, after moving a tensor timestep onto the device of `self.timesteps`.
        """
        if self.begin_index is not None:
            self._step_index = self._begin_index
            return

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)

770
771
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE, using the multistep DEIS solver.

        The model output is converted with `convert_model_output` and appended to the rolling history
        `self.model_outputs`. Then a first-, second- or third-order DEIS update is applied. The order used depends on
        `config.solver_order`, on how many outputs have been collected so far (`self.lower_order_nums`), and on
        whether `config.lower_order_final` lowers the order on the last two steps of schedules shorter than 15 steps.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain. It is only used to initialise the internal
                step index when that index is not yet set.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet (`self.num_inference_steps` is `None`).
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        num_steps = len(self.timesteps)
        # Short schedules may optionally finish with lower-order updates for stability.
        shorten_tail = self.config.lower_order_final and num_steps < 15
        force_first_order = self.step_index == num_steps - 1 and shorten_tail
        force_second_order = self.step_index == num_steps - 2 and shorten_tail

        model_output = self.convert_model_output(model_output, sample=sample)

        order = self.config.solver_order
        # Shift the history left by one slot, then store the newest converted output in the last slot.
        for slot in range(1, order):
            self.model_outputs[slot - 1] = self.model_outputs[slot]
        self.model_outputs[-1] = model_output

        if order == 1 or self.lower_order_nums < 1 or force_first_order:
            prev_sample = self.deis_first_order_update(model_output, sample=sample)
        elif order == 2 or self.lower_order_nums < 2 or force_second_order:
            prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
        else:
            prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)

        # Count collected outputs until the full solver order is available.
        if self.lower_order_nums < order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1

        return SchedulerOutput(prev_sample=prev_sample) if return_dict else (prev_sample,)

835
    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Return `sample` unchanged.

        This scheduler does not rescale the model input. The method exists so that it can be used interchangeably
        with schedulers that do scale the denoising model input depending on the current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                The very same `sample` object, unscaled.
        """
        # Extra positional/keyword arguments are accepted only for API compatibility and are ignored.
        del args, kwargs
        return sample

850
    # Adapted from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Noise `original_samples` to the given timesteps as `alpha_t * original_samples + sigma_t * noise`.

        Args:
            original_samples (`torch.Tensor`):
                The clean samples. The per-timestep sigmas are broadcast along their leading dimension.
            noise (`torch.Tensor`):
                The noise to mix in.
            timesteps (`torch.IntTensor`):
                The timesteps to noise to; presumably one per entry of the leading dimension of `original_samples`.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        device = original_samples.device
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=device, dtype=original_samples.dtype)
        # mps does not support float64, so floating-point timesteps are compared in float32 there
        cast_kwargs = {"dtype": torch.float32} if device.type == "mps" and torch.is_floating_point(timesteps) else {}
        schedule_timesteps = self.timesteps.to(device, **cast_kwargs)
        timesteps = timesteps.to(device, **cast_kwargs)

        if self.begin_index is None:
            # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            if self.step_index is not None:
                # add_noise is called after first denoising step (for inpainting)
                start_index = self.step_index
            else:
                # add noise is called before first denoising step to create initial latent(img2img)
                start_index = self.begin_index
            step_indices = [start_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        missing_dims = original_samples.ndim - sigma.ndim
        if missing_dims > 0:
            # Append singleton dimensions so sigma broadcasts over every non-leading dimension of the samples.
            sigma = sigma.reshape(*sigma.shape, *([1] * missing_dims))

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        return alpha_t * original_samples + sigma_t * noise

    def __len__(self):
        """Return the number of training timesteps configured for this scheduler (`config.num_train_timesteps`)."""
        num_train_timesteps = self.config.num_train_timesteps
        return num_train_timesteps