scheduling_deis_multistep.py 39.4 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 FLAIR Lab and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Quentin Gallouédec's avatar
Quentin Gallouédec committed
15
# DISCLAIMER: check https://huggingface.co/papers/2204.13902 and https://github.com/qsh-zh/deis for more info
16
17
18
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
19
from typing import List, Literal, Optional, Tuple, Union
20
21
22
23
24

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
25
from ..utils import deprecate, is_scipy_available
Kashif Rasul's avatar
Kashif Rasul committed
26
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
28


29
30
31
32
if is_scipy_available():
    import scipy.stats


33
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
YiYi Xu's avatar
YiYi Xu committed
34
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Discretize a continuous `alpha_bar` curve into a sequence of per-step betas.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` at normalized time `t` in `[0, 1]`. The `i`-th beta is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)` with `N = num_diffusion_timesteps`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            How many betas to generate.
        max_beta (`float`, defaults to `0.999`):
            Upper bound applied to every beta; keep it below 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            Which `alpha_bar` curve to discretize.

    Returns:
        `torch.Tensor`:
            A 1-D float32 tensor with `num_diffusion_timesteps` betas.

    Raises:
        `ValueError`: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    # Ratio of alpha_bar at consecutive grid points gives (1 - beta) for that step.
    betas = [min(1 - alpha_bar((idx + 1) / steps) / alpha_bar(idx / steps), max_beta) for idx in range(steps)]
    return torch.tensor(betas, dtype=torch.float32)


class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
81
    `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs).
82

83
84
    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.
85
86

    Args:
87
88
89
90
91
92
93
94
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
95
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample), `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction` (converted to a sample
            prediction as `sample - sigma * model_output`).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
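        algorithm_type (`str`, defaults to `deis`):
            The algorithm type for the solver. Only `deis` is implemented; `dpmsolver` and `dpmsolver++` are accepted
            and mapped to `deis`.
        solver_type (`str`, defaults to `logrho`):
            The solver type for the multistep update. Only `logrho` is implemented; `midpoint`, `heun`, `bh1` and
            `bh2` are accepted and mapped to `logrho`.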
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
116
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
117
118
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
119
120
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
121
122
123
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
124
125
126
127
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
128
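            An offset added to the inference steps, as required by some model families. It is only applied when
            `timestep_spacing="leading"`.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow-matching sigmas (`alpha_t = 1 - sigma`, `sigma_t = sigma`) for step sizes in the noise
            schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            The shift applied to the flow sigmas when `use_flow_sigmas=True`.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether `set_timesteps` may be given a `mu` argument, which overrides `flow_shift` with `exp(mu)`.
        time_shift_type (`str`, defaults to `"exponential"`):
            The type of dynamic time shifting; `mu` can only be passed to `set_timesteps` when this is
            `"exponential"`.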
129
130
    """

Kashif Rasul's avatar
Kashif Rasul committed
131
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "deis",
        solver_type: str = "logrho",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        use_dynamic_shifting: bool = False,
        time_shift_type: str = "exponential",
    ):
        """Build the training noise schedule and the initial solver state; parameters are described on the class."""
        # Validate the sigma-spacing options: beta sigmas need scipy, and at most one spacing may be enabled.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        sigma_spacing_flags = (
            self.config.use_beta_sigmas,
            self.config.use_exponential_sigmas,
            self.config.use_karras_sigmas,
        )
        if sum(sigma_spacing_flags) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )

        # Training beta schedule; explicitly supplied betas win over every named schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # Linear in sqrt(beta) space, then squared; this schedule is very specific to the latent diffusion model.
            sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
            self.betas = sqrt_betas**2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Derived quantities; only VP-type noise schedules are supported.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # Standard deviation of the initial noise distribution.
        self.init_noise_sigma = 1.0

        # DEIS settings: the DPM-Solver algorithm names and the listed solver types are accepted but mapped onto the
        # single implemented variant; anything else is rejected.
        if algorithm_type not in ("deis",):
            if algorithm_type not in ("dpmsolver", "dpmsolver++"):
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
            self.register_to_config(algorithm_type="deis")

        if solver_type not in ("logrho",):
            if solver_type not in ("midpoint", "heun", "bh1", "bh2"):
                raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")
            self.register_to_config(solver_type="logrho")

        # Values that `set_timesteps` later overwrites; until then every training timestep is used, in reverse.
        self.num_inference_steps = None
        all_train_steps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)
        self.timesteps = torch.from_numpy(all_train_steps[::-1].copy())
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
212
213
214
215

    @property
    def step_index(self):
        """
        Counter for the current timestep; it grows by one after each scheduler step and is `None` until initialized.
        """
        return self._step_index
219

220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
    @property
    def begin_index(self):
        """
        Index of the first timestep, meant to be provided by the pipeline via `set_begin_index`; `None` when unset.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Record the index of the first timestep. Pipelines should call this before running inference.

        Args:
            begin_index (`int`, defaults to `0`):
                Index of the timestep to begin from.
        """
        self._begin_index = begin_index

238
239
240
    def set_timesteps(
        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Also rebuilds `self.sigmas` (kept on CPU) and resets the multistep solver state (`model_outputs`,
        `lower_order_nums`, `step_index`, `begin_index`).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                Dynamic-shifting parameter. When given, `config.flow_shift` is overwritten with `exp(mu)`; the new
                value persists for later calls that omit `mu`. Requires `use_dynamic_shifting=True` and
                `time_shift_type="exponential"`.

        Raises:
            `ValueError`: If `mu` is given while exponential dynamic shifting is not enabled, or if
                `config.timestep_spacing` is not one of `"linspace"`, `"leading"` or `"trailing"`.
        """
        if mu is not None:
            # Explicit check rather than `assert`: asserts are stripped under `python -O`, which would silently
            # overwrite `flow_shift` on schedulers that never enabled dynamic shifting.
            if not self.config.use_dynamic_shifting or self.config.time_shift_type != "exponential":
                raise ValueError(
                    "`mu` can only be passed to `set_timesteps` when `use_dynamic_shifting=True` and "
                    f"`time_shift_type='exponential'`, but got use_dynamic_shifting={self.config.use_dynamic_shifting} "
                    f"and time_shift_type={self.config.time_shift_type!r}."
                )
            self.config.flow_shift = np.exp(mu)

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Training sigmas, indexed by training timestep; `log_sigmas` is the lookup table for `_sigma_to_t`.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        # Every branch appends one extra trailing sigma so that `self.sigmas[step_index + 1]` exists on the last step.
        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_exponential_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            # NOTE(review): unlike the Karras branch these timesteps are not rounded, so the int64 cast below
            # truncates them; kept as-is to preserve existing outputs.
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_beta_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            # NOTE(review): not rounded either (see the exponential branch).
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            # The final sigma is the one of training timestep 0.
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
320

321
    # Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample; uses `math.prod` instead of
    # `np.prod` so that samples without trailing dimensions can be flattened.
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Apply dynamic thresholding to the predicted sample.

        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487

        For every batch element, `s` is the `config.dynamic_thresholding_ratio` quantile of the absolute values,
        clamped to `[1, config.sample_max_value]`. The sample is then clamped to `[-s, s]` and divided by `s`.

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded, shaped `(batch_size, channels, *remaining_dims)` (at least two
                dimensions; `remaining_dims` may be empty).

        Returns:
            `torch.Tensor`:
                The thresholded sample, with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image. `math.prod` yields the int 1 for an empty
        # `remaining_dims` (2-D input), whereas `np.prod([])` yields the float 1.0, which `reshape` rejects.
        sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
364

365
366
367
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Map noise level(s) to (fractional) timestep(s) by linear interpolation in log-sigma space.

        Args:
            sigma:
                Noise level(s) to convert; the result is reshaped to `sigma.shape`.
            log_sigmas (`np.ndarray`):
                1-D table of log training sigmas, indexed by timestep.

        Returns:
            `np.ndarray`: The interpolated timestep(s), shaped like `sigma`.
        """
        # Clamp before the log so a zero sigma does not produce -inf.
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # Signed distance from every table entry (rows) to each query value (columns).
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # Last table index whose entry does not exceed the query (0 if none), capped so that `low + 1` stays in range.
        last_low = log_sigmas.shape[0] - 2
        low_idx = np.cumsum(dists >= 0, axis=0).argmax(axis=0).clip(max=last_low)
        high_idx = low_idx + 1

        low, high = log_sigmas[low_idx], log_sigmas[high_idx]

        # Interpolation weight between the two neighbouring table entries, restricted to [0, 1].
        weight = np.clip((low - log_sigma) / (low - high), 0, 1)

        t = (1 - weight) * low_idx + weight * high_idx
        return t.reshape(sigma.shape)

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert a noise level into the `(alpha_t, sigma_t)` pair used by the solver.

        With `config.use_flow_sigmas` the pair is `(1 - sigma, sigma)`; otherwise the VP relation
        `alpha_t = 1 / sqrt(sigma**2 + 1)`, `sigma_t = sigma * alpha_t` is used.
        """
        if not self.config.use_flow_sigmas:
            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
            return alpha_t, sigma * alpha_t
        return 1 - sigma, sigma

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
401
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022).

        Produces `num_inference_steps` sigmas spaced uniformly in `sigma ** (1 / rho)` (with `rho = 7`) from
        `sigma_max` down to `sigma_min`. The bounds come from `config.sigma_min` / `config.sigma_max` when those
        entries exist and are not `None`, and otherwise from the last / first entries of `in_sigmas`.
        """
        # Not every scheduler that shares this helper defines sigma bounds in its config, hence the fallbacks.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule.

        Produces `num_inference_steps` sigmas spaced uniformly in log space from `sigma_max` down to `sigma_min`. The
        bounds come from `config.sigma_min` / `config.sigma_max` when those entries exist and are not `None`, and
        otherwise from the last / first entries of `in_sigmas`.
        """
        # Not every scheduler that shares this helper defines sigma bounds in its config, hence the fallbacks.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        log_grid = np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)
        return np.exp(log_grid)

448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)

        Each sigma is `sigma_min + q * (sigma_max - sigma_min)`, where `q` is the Beta(`alpha`, `beta`) quantile of a
        point on a uniform grid running from 1 down to 0. The bounds come from `config.sigma_min` /
        `config.sigma_max` when those entries exist and are not `None`, and otherwise from the last / first entries
        of `in_sigmas`.
        """
        # Not every scheduler that shares this helper defines sigma bounds in its config, hence the fallbacks.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        sigma_span = sigma_max - sigma_min
        grid = 1 - np.linspace(0, 1, num_inference_steps)
        quantiles = [scipy.stats.beta.ppf(timestep, alpha, beta) for timestep in grid]
        return np.array([sigma_min + (ppf * sigma_span) for ppf in quantiles])

480
    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DEIS algorithm needs.

        The raw output is first turned into a data (`x0`) prediction according to `config.prediction_type`,
        optionally passed through dynamic thresholding, and finally re-expressed as a noise prediction
        `(sample - alpha_t * x0_pred) / sigma_t` at the noise level `self.sigmas[self.step_index]`.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                Deprecated and ignored; may be passed as the first extra positional argument or as a keyword.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process; may also be passed as the second
                extra positional argument.

        Returns:
            `torch.Tensor`:
                The converted model output.

        Raises:
            `ValueError`: If `sample` is missing or `config.prediction_type` is not supported.
        """
        if args:
            timestep = args[0]
        else:
            timestep = kwargs.pop("timestep", None)

        if sample is None:
            if len(args) < 2:
                raise ValueError("missing `sample` as a required keyword argument")
            sample = args[1]

        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])

        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif prediction_type == "sample":
            x0_pred = model_output
        elif prediction_type == "v_prediction":
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif prediction_type == "flow_prediction":
            # The raw sigma is used as sigma_t here; this matches `_sigma_to_alpha_sigma_t` when `use_flow_sigmas`.
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        if self.config.algorithm_type != "deis":
            raise NotImplementedError("only support log-rho multistep deis now")
        return (sample - alpha_t * x0_pred) / sigma_t

    def deis_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DEIS (equivalent to DDIM).

        Moves `sample` from noise level `self.sigmas[step_index]` (s) to `self.sigmas[step_index + 1]` (t) via
        `x_t = (alpha_t / alpha_s) * x_s - sigma_t * (exp(h) - 1) * model_output`, where `h = lambda_t - lambda_s`
        and `lambda = log(alpha) - log(sigma)`.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                Deprecated and ignored (first extra positional argument or keyword).
            prev_timestep (`int`):
                Deprecated and ignored (second extra positional argument or keyword).
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process; may also be passed as the third
                extra positional argument.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if args else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) >= 2 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) < 3:
                raise ValueError("missing `sample` as a required keyword argument")
            sample = args[2]

        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        next_sigma, cur_sigma = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(next_sigma)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(cur_sigma)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
        h = lambda_t - lambda_s

        if self.config.algorithm_type != "deis":
            raise NotImplementedError("only support log-rho multistep deis now")
        return (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output

    def multistep_deis_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DEIS.

        With `rho = sigma / alpha`, the two most recent model outputs are interpolated linearly in `log(rho)` and the
        interpolant is integrated in closed form from the current noise level (s0) to the next one (t).

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps; only the last two
                entries are used (`[-1]` at the current step, `[-2]` at the previous one).
            timestep_list:
                Deprecated and ignored (first extra positional argument or keyword).
            prev_timestep:
                Deprecated and ignored (second extra positional argument or keyword).
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process; may also be passed as the third
                extra positional argument.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if args else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) >= 2 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) < 3:
                raise ValueError("missing `sample` as a required keyword argument")
            sample = args[2]

        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # Noise levels at the next (t), current (s0) and previous (s1) steps.
        step = self.step_index
        raw_sigmas = [self.sigmas[step + offset] for offset in (1, 0, -1)]
        (alpha_t, sigma_t), (alpha_s0, sigma_s0), (alpha_s1, sigma_s1) = (
            self._sigma_to_alpha_sigma_t(raw) for raw in raw_sigmas
        )

        m0, m1 = model_output_list[-1], model_output_list[-2]

        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1

        if self.config.algorithm_type != "deis":
            raise NotImplementedError("only support log-rho multistep deis now")

        def antiderivative(t, b, c):
            # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]: the log-space Lagrange basis that is 1 at b, 0 at c.
            return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))

        coef1 = antiderivative(rho_t, rho_s0, rho_s1) - antiderivative(rho_s0, rho_s0, rho_s1)
        coef2 = antiderivative(rho_t, rho_s1, rho_s0) - antiderivative(rho_s0, rho_s1, rho_s0)

        return alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)

    def multistep_deis_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DEIS.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from the learned diffusion model. Only the last three entries are used:
                `model_output_list[-1]` is the output at the current step (`self.step_index`),
                `model_output_list[-2]` the one from `self.step_index - 1` and `model_output_list[-3]` the one from
                `self.step_index - 2`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. For backward compatibility it may
                also be passed as the third positional argument after `model_output_list`.
            *args:
                Legacy positional arguments `(timestep_list, prev_timestep, sample)`. `timestep_list` and
                `prev_timestep` are deprecated and ignored.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` is provided neither as a keyword nor positionally.
            NotImplementedError: If `self.config.algorithm_type` is not `"deis"`.
        """
        # Legacy positional/keyword arguments; they no longer affect the computation.
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # t is the target (next) noise level, s0 the current one, s1/s2 the ones from the two previous steps.
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        # rho = sigma / alpha is the integration variable of the log-rho DEIS update.
        rho_t, rho_s0, rho_s1, rho_s2 = (
            sigma_t / alpha_t,
            sigma_s0 / alpha_s0,
            sigma_s1 / alpha_s1,
            sigma_s2 / alpha_s2,
        )

        if self.config.algorithm_type == "deis":

            def ind_fn(t, b, c, d):
                # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}]
                # Antiderivative of the quadratic Lagrange basis (in log rho) that is 1 at rho=b and 0 at rho=c, d.
                numerator = t * (
                    np.log(c) * (np.log(d) - np.log(t) + 1)
                    - np.log(d) * np.log(t)
                    + np.log(d)
                    + np.log(t) ** 2
                    - 2 * np.log(t)
                    + 2
                )
                denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d))
                return numerator / denominator

            # Each coefficient integrates one basis polynomial from rho_s0 to rho_t.
            coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2)
            coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0)
            coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1)

            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2)

            return x_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` within `schedule_timesteps` (defaults to `self.timesteps`).

        If `timestep` is not found, the last index of `self.timesteps` is returned. If it occurs more than once, the
        index of the second occurrence is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        If `self.begin_index` is set, the counter starts at `self._begin_index`. Otherwise it is looked up from
        `timestep` via `self.index_for_timestep`; a tensor `timestep` is first moved to the device of
        `self.timesteps` so the comparison happens on a single device.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DEIS.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain. It is only used to initialize the internal step
                counter on the first call; subsequent calls advance `self.step_index` instead.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `self.num_inference_steps` is `None`, i.e. `set_timesteps` has not been run.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # With `lower_order_final` enabled on short schedules (< 15 steps), the final step falls back to first
        # order and the second-to-last step to at most second order.
        lower_order_final = (
            (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        # Shift the history window left by one and store the newest converted output in the last slot.
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # `lower_order_nums` counts completed steps (capped at `solver_order`), so early steps use the highest
        # order the accumulated history supports.
        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.deis_first_order_update(model_output, sample=sample)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
        else:
            prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. DEIS needs no scaling, so the input is returned unchanged.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            *args, **kwargs:
                Accepted for API compatibility (e.g. a timestep) and ignored.

        Returns:
            `torch.Tensor`:
                The same `sample` object, unmodified.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Noise `original_samples` to the noise level associated with each entry of `timesteps`.

        Args:
            original_samples (`torch.Tensor`):
                The clean samples.
            noise (`torch.Tensor`):
                The noise to mix in; combined elementwise with `original_samples`.
            timesteps (`torch.IntTensor`):
                Timesteps to noise to. The selected sigmas are broadcast from the first dimension of
                `original_samples`, so presumably one timestep per sample along that dimension.

        Returns:
            `torch.Tensor`:
                `alpha_t * original_samples + sigma_t * noise`, where `(alpha_t, sigma_t)` come from
                `self._sigma_to_alpha_sigma_t` applied to the selected sigmas.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast the per-sample sigmas over the remaining dimensions of original_samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        """Report the scheduler's length as the configured number of training timesteps."""
        train_timesteps = self.config.num_train_timesteps
        return train_timesteps