# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
19
from typing import List, Literal, Optional, Tuple, Union
20
21
22
23
24

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
25
from ..utils import deprecate, is_scipy_available
26
27
28
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


29
30
31
32
if is_scipy_available():
    import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`:
            Rescaled betas with zero terminal SNR.
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


116
117
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise. The legacy values `"midpoint"`, `"heun"`, and `"logrho"` are accepted and mapped to `"bh2"`.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            The steps at which to disable the corrector, to mitigate the misalignment between `epsilon_theta(x_t, c)`
            and `epsilon_theta(x_t^c, c)`, which can influence convergence for a large guidance scale. The corrector
            is usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            The shift applied to the flow sigmas when `use_flow_sigmas=True`; each sigma `s` is mapped to
            `flow_shift * s / (1 + (flow_shift - 1) * s)`. It is overwritten with `exp(mu)` when `set_timesteps` is
            called with `mu`.
        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied when
            `timestep_spacing="leading"`.
        final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether dynamic shifting is enabled. Must be `True` to pass `mu` to `set_timesteps`.
        time_shift_type (`"exponential"`, defaults to `"exponential"`):
            The type of dynamic time shift. Only `"exponential"` is supported, which sets `flow_shift = exp(mu)`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: Literal["bh1", "bh2"] = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: Optional[SchedulerMixin] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
        rescale_betas_zero_snr: bool = False,
        use_dynamic_shifting: bool = False,
        time_shift_type: Literal["exponential"] = "exponential",
    ) -> None:
        # Build the training noise schedule and reset all per-run sampling state. See the class docstring for the
        # meaning of each argument.

        # `self.config` is read here before any attribute is assigned; this relies on `@register_to_config`
        # having populated it.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                # Legacy solver names are accepted and remapped to "bh2" in the stored config.
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        # Store a private copy: the default `[]` is a single list object shared by every call, so keeping a
        # reference would let one instance's mutations leak into others (or into a list owned by the caller).
        # `None` is treated as "no steps disabled".
        self.disable_corrector = list(disable_corrector) if disable_corrector is not None else []
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
282
    def step_index(self) -> Optional[int]:
283
        """
M. Tolga Cangöz's avatar
M. Tolga Cangöz committed
284
        The index counter for current timestep. It will increase 1 after each scheduler step.
285
286
        """
        return self._step_index
287

288
    @property
289
    def begin_index(self) -> Optional[int]:
290
291
292
293
294
295
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
296
    def set_begin_index(self, begin_index: int = 0) -> None:
297
298
299
300
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
301
            begin_index (`int`, defaults to `0`):
302
303
304
305
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

306
    def set_timesteps(
307
308
        self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
    ) -> None:
309
        """
310
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
311
312
313

        Args:
            num_inference_steps (`int`):
314
315
316
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
317
318
            mu (`float`, *optional*):
                Optional mu parameter for dynamic shifting when using exponential time shift type.
319
        """
Quentin Gallouédec's avatar
Quentin Gallouédec committed
320
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
321
322
323
        if mu is not None:
            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
            self.config.flow_shift = np.exp(mu)
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )
347

348
349
350
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
351
            sigmas = np.flip(sigmas).copy()
352
353
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
354
355
356
357
358
359
360
361
362
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
363
        elif self.config.use_exponential_sigmas:
364
365
366
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
367
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
368
369
370
371
372
373
374
375
376
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
377
        elif self.config.use_beta_sigmas:
378
379
380
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
381
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
382
383
384
385
386
387
388
389
390
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
391
392
393
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
hlky's avatar
hlky committed
394
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
395
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
396
397
398
399
400
401
402
403
404
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
405
406
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
407
408
409
410
411
412
413
414
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
415
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
416

417
418
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
419
420
421

        self.num_inference_steps = len(timesteps)

422
423
424
425
426
427
        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
428
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)
429

430
431
        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
432
        self._begin_index = None
433
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
434

435
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
436
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
437
        """
438
439
        Apply dynamic thresholding to the predicted sample.

440
441
442
443
444
445
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

Quentin Gallouédec's avatar
Quentin Gallouédec committed
446
        https://huggingface.co/papers/2205.11487
447
448
449
450
451
452
453
454

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded.

        Returns:
            `torch.Tensor`:
                The thresholded sample.
455
456
        """
        dtype = sample.dtype
457
        batch_size, channels, *remaining_dims = sample.shape
458
459
460
461
462

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
463
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
464
465
466
467
468
469
470
471
472
473

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

474
        sample = sample.reshape(batch_size, channels, *remaining_dims)
475
476
477
        sample = sample.to(dtype)

        return sample
478

479
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
480
    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
481
482
483
484
485
486
487
488
489
490
491
492
493
        """
        Convert sigma values to corresponding timestep values through interpolation.

        Args:
            sigma (`np.ndarray`):
                The sigma value(s) to convert to timestep(s).
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s).
        """
494
        # get log sigma
495
        log_sigma = np.log(np.maximum(sigma, 1e-10))
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

516
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
517
    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
518
519
520
521
522
523
524
525
526
527
528
        """
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        """
529
530
531
532
533
534
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
        else:
            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
            sigma_t = sigma * alpha_t
535
536
537

        return alpha_t, sigma_t

538
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
539
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
540
541
542
543
544
545
546
547
548
549
550
551
552
553
        """
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following the Karras noise schedule.
        """
554

Suraj Patil's avatar
Suraj Patil committed
555
556
557
558
559
560
561
562
563
564
565
566
567
568
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
569
570
571
572
573
574
575
576

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

577
578
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> np.ndarray:
        """
        Construct an exponential noise schedule.

        The sigmas are spaced uniformly in log-space from `sigma_max` down to `sigma_min`. Each bound is taken from
        the scheduler config when the config defines it and it is not `None`. Otherwise it falls back to an entry
        of `in_sigmas`: the first entry for `sigma_max`, the last entry for `sigma_min`.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted. Only the first and last entries are read, and only as
                fallbacks for bounds missing from the config.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `np.ndarray`:
                `num_inference_steps` sigma values following an exponential schedule from `sigma_max` to `sigma_min`.
        """
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Both bounds must be positive: `math.log` raises `ValueError` on zero or negative input.
        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> np.ndarray:
        """
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        The beta-distribution quantile function is evaluated at evenly spaced points running from 1 down to 0. The
        resulting quantiles in [0, 1] are mapped linearly onto [`sigma_min`, `sigma_max`]. Each bound is taken from
        the scheduler config when the config defines it and it is not `None`. Otherwise it falls back to an entry
        of `in_sigmas`: the first entry for `sigma_max`, the last entry for `sigma_min`.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted. Only the first and last entries are read, and only as
                fallbacks for bounds missing from the config.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `np.ndarray`:
                `num_inference_steps` sigma values following a beta distribution schedule, from `sigma_max` to
                `sigma_min`.
        """
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # ppf(1) == 1 and ppf(0) == 0, so the schedule starts at sigma_max and ends at sigma_min.
        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs.

        When `self.predict_x0` is true, the output is converted to a data (`x0`) prediction and optionally
        thresholded. Otherwise it is converted to a noise (`epsilon`) prediction. The sigma used for the conversion
        is read from `self.sigmas` at the internal counter `self.step_index`.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                Deprecated and ignored. May be passed as the first positional argument or as a keyword; a
                non-`None` value only triggers a deprecation warning.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. May also be passed as the second
                positional argument.

        Returns:
            `torch.Tensor`:
                The converted model output.

        Raises:
            ValueError: If `sample` is missing, or if `config.prediction_type` is not supported for the selected
                parameterization.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        prediction_type = self.config.prediction_type

        if self.predict_x0:
            if prediction_type == "epsilon":
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif prediction_type == "sample":
                x0_pred = model_output
            elif prediction_type == "v_prediction":
                x0_pred = alpha_t * sample - sigma_t * model_output
            elif prediction_type == "flow_prediction":
                # Flow prediction scales the output by the raw schedule sigma, not the `sigma_t` returned by
                # `_sigma_to_alpha_sigma_t`.
                x0_pred = sample - sigma * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # Noise-prediction parameterization: express the output as epsilon.
        if prediction_type == "epsilon":
            return model_output
        if prediction_type == "sample":
            return (sample - alpha_t * model_output) / sigma_t
        if prediction_type == "v_prediction":
            return alpha_t * model_output + sigma_t * sample
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction` for the UniPCMultistepScheduler."
        )

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.

        The predictor advances `sample` from `self.sigmas[self.step_index]` to `self.sigmas[self.step_index + 1]`.
        It uses the converted model outputs stored in `self.model_outputs`, the newest of which is the last entry.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep. Only forwarded to
                `self.solver_p`; the built-in predictor reads `self.model_outputs` instead.
            prev_timestep (`int`):
                Deprecated and ignored. May be passed as the first positional argument or as a keyword; a
                non-`None` value only triggers a deprecation warning.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process. May also be passed as the second
                positional argument.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p). May also be passed as the
                third positional argument.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            ValueError: If `sample` or `order` is missing.
            NotImplementedError: If `config.solver_type` is neither `"bh1"` nor `"bh2"`.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            # Delegate the whole predictor to the user-supplied solver, which receives the raw model output.
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        # Step from the current sigma (s0) to the next one (t) in the schedule.
        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # lambda = log(alpha) - log(sigma); h is the step size measured in lambda.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        # For each of the `order - 1` older outputs: its lambda offset relative to h (r_k) and the
        # difference to the newest output divided by r_k (D1_k).
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Row i-1 of R holds rks ** (i - 1); b[i-1] holds the matching phi term scaled by i! / B_h.
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K, ...)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
            # Weighted sum of the differences over K; identical for both parameterizations.
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version).

        The corrector recomputes the step from `self.sigmas[self.step_index - 1]` to `self.sigmas[self.step_index]`
        starting at `last_sample`. It combines the history in `self.model_outputs` with `this_model_output`, the
        converted model output evaluated at `this_sample`.

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                Deprecated and ignored. May be passed as the first positional argument or as a keyword; a
                non-`None` value only triggers a deprecation warning.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`. May also be passed as the second
                positional argument.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`. Only its device is used here. May also be
                passed as the third positional argument.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. May also be
                passed as the fourth positional argument.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.

        Raises:
            ValueError: If `last_sample`, `this_sample` or `order` is missing.
            NotImplementedError: If `config.solver_type` is neither `"bh1"` nor `"bh2"`.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError("missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError("missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        model_t = this_model_output

        # Correct over the interval that ends at the current step: sigma[step_index - 1] (s0) -> sigma[step_index] (t).
        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # lambda = log(alpha) - log(sigma); h is the step size measured in lambda.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        # For each of the `order - 1` older outputs: its lambda offset relative to h (r_k) and the
        # difference to m0 divided by r_k (D1_k).
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Row i-1 of R holds rks ** (i - 1); b[i-1] holds the matching phi term scaled by i! / B_h.
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        # Contribution of the older outputs (all but the last coefficient); identical for both parameterizations.
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0

        # The newest model output, evaluated at `this_sample`, enters through the last coefficient.
        D1_t = model_t - m0
        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(
        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule. If `timestep` occurs more than once, the index of the
                second occurrence is returned. If it does not occur, `len(self.timesteps) - 1` is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            # NOTE(review): the fallback uses the length of `self.timesteps`, not of `schedule_timesteps`. The two
            # agree for `add_noise`, which passes a device copy of `self.timesteps`.
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
        """
        Initialize the step_index counter for the scheduler.

        If `self.begin_index` is set, the counter starts there. Otherwise it is looked up from `timestep` via
        `self.index_for_timestep`; a tensor `timestep` is first moved to the device of `self.timesteps`.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep UniPC.

        Each call does four things:

        1. Converts `model_output`.
        2. Applies the UniC corrector to `sample`. This is skipped on the first step, when no previous predictor
           result is stored, or when the previous step index is listed in `self.disable_corrector`.
        3. Shifts the converted output and `timestep` into the history buffers.
        4. Runs the UniP predictor and advances `self._step_index` by one.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called (`num_inference_steps` is `None`).
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        use_corrector = (
            self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Shift the history buffers left by one; the newest entries go in the last slot.
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.config.lower_order_final:
            # Never use more history than there are remaining steps.
            this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            *args, **kwargs:
                Accepted for API compatibility and ignored.

        Returns:
            `torch.Tensor`:
                The input `sample` object itself, unchanged and not copied.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Computes `alpha_t * original_samples + sigma_t * noise`, where `(alpha_t, sigma_t)` come from
        `self._sigma_to_alpha_sigma_t` applied to the schedule sigma chosen for each sample.

        Args:
            original_samples (`torch.Tensor`):
                The original samples without noise.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps at which to add noise to the samples. Only their count is used once `begin_index`
                has been set.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # One sigma per sample, with trailing singleton dims appended so it broadcasts against the samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self) -> int:
        """Return the configured number of training timesteps (`self.config.num_train_timesteps`)."""
        return self.config.num_train_timesteps