# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
19
from typing import List, Literal, Optional, Tuple, Union
20
21
22
23
24

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
25
from ..utils import deprecate, is_scipy_available
26
27
28
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


29
30
31
32
if is_scipy_available():
    import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


# Adapted from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is exactly zero, following Algorithm 1 of
    https://huggingface.co/papers/2305.08891.

    The schedule is mapped to sqrt(alpha_bar) space, shifted so the final value is zero and rescaled so the first
    value is unchanged. It is then mapped back to per-step betas.

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`:
            Rescaled betas with zero terminal SNR (the final beta is 1).
    """
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last entry becomes 0, then scale so the first entry keeps its original value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas


116
117
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, default `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize sampling for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            The steps at which to disable the corrector, to mitigate the misalignment between `epsilon_theta(x_t, c)`
            and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. The corrector is
            usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            Shift applied to the flow sigmas when `use_flow_sigmas=True`, computed as `shift * s / (1 + (shift - 1) *
            s)`. A value of 1.0 leaves the sigmas unchanged.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        use_dynamic_shifting (`bool`, defaults to `False`):
            Must be `True` to pass `mu` to `set_timesteps`, which then overrides `flow_shift` with `exp(mu)`.
        time_shift_type (`str`, defaults to `"exponential"`):
            The time-shift type used with dynamic shifting. Only `"exponential"` is accepted when `mu` is passed to
            `set_timesteps`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        rescale_betas_zero_snr: bool = False,
        use_dynamic_shifting: bool = False,
        time_shift_type: str = "exponential",
    ):
        """
        Build the training beta/alpha schedule and the initial full-length timestep grid. The arguments are read
        back through `self.config` and are documented on the class.

        Raises:
            ImportError: if `use_beta_sigmas=True` but scipy is not installed.
            ValueError: if more than one of the beta/exponential/Karras sigma options is enabled.
            NotImplementedError: for an unknown `beta_schedule` or `solver_type`.
        """
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                # Legacy DPMSolver solver names are mapped onto the UniPC "bh2" variant.
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        # Keep a private copy: the signature default `[]` is a single list shared by every instance, so aliasing it
        # would let one scheduler's mutation leak into all others. `None` is treated as "no disabled steps".
        self.disable_corrector = list(disable_corrector) if disable_corrector is not None else []
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step, and is reset to
        `None` by `set_timesteps`.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        Index of the first timestep to run. The pipeline sets it through `set_begin_index`; it is `None` until then
        and after every call to `set_timesteps`.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(
        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                Dynamic-shifting parameter. When given, `config.flow_shift` is overwritten with `exp(mu)`. This
                requires `use_dynamic_shifting=True` and `time_shift_type="exponential"`.

        Raises:
            ValueError: for an unsupported `timestep_spacing` or `final_sigmas_type`.
        """
        if mu is not None:
            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
            self.config.flow_shift = np.exp(mu)

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigma_last = self._resolve_final_sigma(sigmas[-1])
        elif self.config.use_exponential_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigma_last = self._resolve_final_sigma(sigmas[-1])
        elif self.config.use_beta_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigma_last = self._resolve_final_sigma(sigmas[-1])
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigma_last = self._resolve_final_sigma(sigmas[-1])
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            # In the plain VP schedule "sigma_min" means the smallest sigma of the *training* schedule.
            sigma_last = self._resolve_final_sigma(((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5)

        # One extra sigma is appended after the last inference step.
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        # Reset the multistep history (as in `__init__`) so a previous sampling run cannot leak into this one.
        self.model_outputs = [None] * self.config.solver_order
        self.timestep_list = [None] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def _resolve_final_sigma(self, sigma_min):
        """
        Return the sigma appended after the last inference step, according to `config.final_sigmas_type`.

        Args:
            sigma_min:
                The value to use when `final_sigmas_type == "sigma_min"`.

        Returns:
            `sigma_min` for `"sigma_min"`, `0` for `"zero"`.

        Raises:
            ValueError: for any other `final_sigmas_type`.
        """
        if self.config.final_sigmas_type == "sigma_min":
            return sigma_min
        if self.config.final_sigmas_type == "zero":
            return 0
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Apply dynamic thresholding to the predicted sample.

        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded. Must have at least two dimensions (batch, channels, ...).

        Returns:
            `torch.Tensor`:
                The thresholded sample, with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image.
        # NOTE: `math.prod` returns the int 1 for an empty `remaining_dims` (2-D input). `np.prod([])` returns the
        # float 1.0 instead, which `reshape` rejects. The DDPM original this is copied from has the same issue.
        sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Convert sigma values to corresponding timestep values through interpolation.

        Args:
            sigma (`np.ndarray`):
                The sigma value(s) to convert to timestep(s).
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s).
        """
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert a sigma value into the `(alpha_t, sigma_t)` pair.

        With `use_flow_sigmas=True` this is `(1 - sigma, sigma)`. Otherwise `sigma` is interpreted as
        `sqrt((1 - alpha_bar) / alpha_bar)`, which gives `alpha_t = 1 / sqrt(sigma**2 + 1)` and
        `sigma_t = sigma * alpha_t`.

        Args:
            sigma:
                A scalar or tensor sigma value; only elementwise arithmetic is applied.

        Returns:
            `tuple`: `(alpha_t, sigma_t)`, of the same type as `sigma`.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
        else:
            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
            sigma_t = sigma * alpha_t

        return alpha_t, sigma_t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following the Karras noise schedule.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following an exponential schedule.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs.

        The result is a data prediction (`x0`) when `self.predict_x0` is set, and a noise prediction (`epsilon`)
        otherwise.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                Deprecated and ignored; the current position is tracked by `self.step_index`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output. When `config.thresholding` is enabled, the data prediction is
                passed through `self._threshold_sample`.

        Raises:
            `ValueError`: If `sample` is not provided or `config.prediction_type` is not supported.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        prediction_type = self.config.prediction_type

        if self.predict_x0:
            if prediction_type == "epsilon":
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif prediction_type == "sample":
                x0_pred = model_output
            elif prediction_type == "v_prediction":
                x0_pred = alpha_t * sample - sigma_t * model_output
            elif prediction_type == "flow_prediction":
                # Flow prediction uses the raw schedule sigma: x0 = x_t - sigma * v.
                x0_pred = sample - sigma * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        if prediction_type == "epsilon":
            return model_output
        if prediction_type == "sample":
            epsilon = (sample - alpha_t * model_output) / sigma_t
            return epsilon
        if prediction_type == "v_prediction":
            epsilon = alpha_t * model_output + sigma_t * sample
            return epsilon
        if prediction_type == "flow_prediction":
            # Invert the data-prediction branch above (x0 = x_t - sigma * v). This assumes the flow-matching
            # interpolation x_t = (1 - sigma) * x0 + sigma * epsilon, i.e. v = epsilon - x0, so
            # epsilon = x0 + v = x_t + (1 - sigma) * v.
            epsilon = sample + (1 - sigma) * model_output
            return epsilon
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
            "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
        )

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Predictor step of UniPC (B(h) variant). If `self.solver_p` is set, its `step` is used instead.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep. It is only forwarded to
                `self.solver_p`; the UniP update itself reads the converted outputs stored in `self.model_outputs`.
            prev_timestep (`int`):
                Deprecated and ignored; the position in the schedule comes from `self.step_index`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (the *p* in UniPC-p).

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep, cast to the dtype of `sample`.
        """
        # Legacy positional layout: (prev_timestep, sample, order).
        if args:
            prev_timestep = args[0]
        else:
            prev_timestep = kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) <= 1:
                raise ValueError("missing `sample` as a required keyword argument")
            sample = args[1]
        if order is None:
            if len(args) <= 2:
                raise ValueError("missing `order` as a required keyword argument")
            order = args[2]
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        history = self.model_outputs
        latest_timestep = self.timestep_list[-1]
        m0 = history[-1]
        x = sample

        if self.solver_p:
            return self.solver_p.step(model_output, latest_timestep, x).prev_sample

        def lambda_of(sig):
            # lambda = log(alpha) - log(sigma) for a single schedule sigma.
            alpha_i, sigma_i = self._sigma_to_alpha_sigma_t(sig)
            return torch.log(alpha_i) - torch.log(sigma_i)

        idx = self.step_index
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[idx + 1])
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(self.sigmas[idx])
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        h = lambda_t - lambda_s0
        device = sample.device

        # For each older stored output: its lambda offset relative to h (r_k) and the scaled difference D1_k.
        ratios, diffs = [], []
        for k in range(1, order):
            r_k = (lambda_of(self.sigmas[idx - k]) - lambda_s0) / h
            ratios.append(r_k)
            diffs.append((history[-(k + 1)] - m0) / r_k)
        rks = torch.tensor(ratios + [1.0], device=device)

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        solver_type = self.config.solver_type
        if solver_type == "bh1":
            B_h = hh
        elif solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # The 1-based i-th row of R is rks ** (i - 1). b comes from the running h_phi_k recurrence.
        rows, rhs = [], []
        for i in range(1, order + 1):
            rows.append(torch.pow(rks, i - 1))
            rhs.append(h_phi_k * math.factorial(i) / B_h)
            h_phi_k = h_phi_k / hh - 1 / math.factorial(i + 1)
        R = torch.stack(rows)
        b = torch.tensor(rhs, device=device)

        if diffs:
            stacked = torch.stack(diffs, dim=1)  # (B, K, ...)
            if order == 2:
                # Simplified coefficient used for order 2.
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, stacked)
        else:
            pred_res = 0

        if self.predict_x0:
            x_t = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - alpha_t * B_h * pred_res
        else:
            x_t = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - sigma_t * B_h * pred_res
        return x_t.to(x.dtype)

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Corrector step of UniPC (B(h) variant).

        The corrected sample is recomputed from `last_sample`. It combines the stored history in
        `self.model_outputs` with the new converted output `this_model_output`. `this_sample` is only used to pick
        the device.

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                Deprecated and ignored; the position in the schedule comes from `self.step_index`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep, cast to the dtype of `last_sample`.
        """
        # Legacy positional layout: (this_timestep, last_sample, this_sample, order).
        if args:
            this_timestep = args[0]
        else:
            this_timestep = kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) <= 1:
                raise ValueError("missing `last_sample` as a required keyword argument")
            last_sample = args[1]
        if this_sample is None:
            if len(args) <= 2:
                raise ValueError("missing `this_sample` as a required keyword argument")
            this_sample = args[2]
        if order is None:
            if len(args) <= 3:
                raise ValueError("missing `order` as a required keyword argument")
            order = args[3]
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        history = self.model_outputs
        m0 = history[-1]
        x = last_sample
        model_t = this_model_output

        def lambda_of(sig):
            # lambda = log(alpha) - log(sigma) for a single schedule sigma.
            alpha_i, sigma_i = self._sigma_to_alpha_sigma_t(sig)
            return torch.log(alpha_i) - torch.log(sigma_i)

        idx = self.step_index
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[idx])
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(self.sigmas[idx - 1])
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        h = lambda_t - lambda_s0
        device = this_sample.device

        # s0 sits at idx - 1 here, so older history entries start one position earlier than in the predictor.
        ratios, diffs = [], []
        for k in range(1, order):
            r_k = (lambda_of(self.sigmas[idx - (k + 1)]) - lambda_s0) / h
            ratios.append(r_k)
            diffs.append((history[-(k + 1)] - m0) / r_k)
        rks = torch.tensor(ratios + [1.0], device=device)

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        solver_type = self.config.solver_type
        if solver_type == "bh1":
            B_h = hh
        elif solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # The 1-based i-th row of R is rks ** (i - 1). b comes from the running h_phi_k recurrence.
        rows, rhs = [], []
        for i in range(1, order + 1):
            rows.append(torch.pow(rks, i - 1))
            rhs.append(h_phi_k * math.factorial(i) / B_h)
            h_phi_k = h_phi_k / hh - 1 / math.factorial(i + 1)
        R = torch.stack(rows)
        b = torch.tensor(rhs, device=device)

        stacked = torch.stack(diffs, dim=1) if diffs else None

        if order == 1:
            # Simplified coefficient used for order 1.
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], stacked) if stacked is not None else 0
        D1_t = model_t - m0
        correction = corr_res + rhos_c[-1] * D1_t

        if self.predict_x0:
            x_t = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - alpha_t * B_h * correction
        else:
            x_t = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - sigma_t * B_h * correction
        return x_t.to(x.dtype)

986
987
988
989
    # Adapted from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    # (the not-found fallback is taken relative to `schedule_timesteps` instead of `self.timesteps`).
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Find the index of `timestep` within a timestep schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep to look up.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The schedule to search. Defaults to `self.timesteps`.

        Returns:
            `int`:
                The index of the matching timestep. If it occurs more than once, the second occurrence is used.
                If it does not occur at all, the last index of the searched schedule is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            # Fall back to the final position of the schedule that was actually searched. A caller-supplied
            # schedule may differ in length from `self.timesteps`, or `self.timesteps` may not be set yet.
            step_index = len(schedule_timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Set the internal `_step_index` counter before the first `step` call.

        If a begin index has been set, `self._begin_index` is used directly. Otherwise the timestep is looked up
        with `self.index_for_timestep`. A tensor timestep is first moved to the device of `self.timesteps`.
        """
        if self.begin_index is not None:
            self._step_index = self._begin_index
            return

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)
1018

1019
1020
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Move the sample one step backwards along the schedule with multistep UniPC.

        When allowed, the incoming sample is first refined with the UniC corrector. The UniP predictor then
        produces the sample at the previous timestep.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # The corrector needs all three of the following: a previous step, no disable flag for that step, and the
        # recorded input of the previous predictor.
        previous_index = self.step_index - 1
        use_corrector = (
            previous_index >= 0 and previous_index not in self.disable_corrector and self.last_sample is not None
        )

        converted = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=converted,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Slide the history window left by one slot and store the newest entries in the last slot.
        for k in range(self.config.solver_order - 1):
            self.model_outputs[k] = self.model_outputs[k + 1]
            self.timestep_list[k] = self.timestep_list[k + 1]
        self.model_outputs[-1] = converted
        self.timestep_list[-1] = timestep

        order = self.config.solver_order
        if self.config.lower_order_final:
            # Cap the order by the number of steps remaining in the schedule.
            order = min(order, len(self.timesteps) - self.step_index)
        # Warm-up: the order cannot exceed the number of completed steps (tracked by lower_order_nums) plus one.
        self.this_order = min(order, self.lower_order_nums + 1)
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # the raw (unconverted) output, in case `solver_p` is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Advance the internal counter once the step is complete.
        self._step_index += 1

        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)
        return (prev_sample,)

1100
    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        No-op input scaling. It exists so this scheduler can be swapped with schedulers that rescale the model input
        per timestep. Extra positional and keyword arguments are accepted and ignored.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                The very same `sample` object, unchanged.
        """
        unchanged = sample
        return unchanged

1115
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
1116
1117
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Noise `original_samples` to the sigma level of each entry in `timesteps`.

        The result is `alpha_t * original_samples + sigma_t * noise`, with `(alpha_t, sigma_t)` obtained from
        `self._sigma_to_alpha_sigma_t`.
        """
        target_device = original_samples.device
        # Move sigmas to the samples' device and dtype.
        sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)
        if target_device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
            timesteps = timesteps.to(target_device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(target_device)
            timesteps = timesteps.to(target_device)

        if self.begin_index is not None:
            # Once denoising has started, the current step index is used (e.g. inpainting). Before that,
            # the begin index is used (e.g. building the initial img2img latent).
            shared_index = self.step_index if self.step_index is not None else self.begin_index
            step_indices = [shared_index] * timesteps.shape[0]
        else:
            # No begin index was set (e.g. training): look every timestep up in the schedule.
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append trailing singleton dims so sigma broadcasts against the sample tensor.
        sigma = sigma.reshape(sigma.shape + (1,) * (original_samples.ndim - sigma.ndim))

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        return alpha_t * original_samples + sigma_t * noise

    def __len__(self):
        # The scheduler's length is its configured number of training timesteps.
        train_steps = self.config.num_train_timesteps
        return train_steps