scheduling_unipc_multistep.py 46.1 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
25
from ..utils import deprecate, is_scipy_available
26
27
28
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


29
30
31
32
if is_scipy_available():
    import scipy.stats


33
34
35
36
37
38
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of shape `(num_diffusion_timesteps,)` holding the betas used by the
        scheduler to step the model outputs. Entry `i` is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at
        `max_beta`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # Each beta is the fraction of alpha_bar lost between two consecutive grid points on [0, 1].
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)


78
79
80
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is exactly zero, following Algorithm 1 of
    https://huggingface.co/papers/2305.08891.

    The square roots of the cumulative alpha products are shifted so the last entry becomes 0, then scaled so the
    first entry keeps its original value; the result is converted back into per-step betas. The input tensor is
    not modified.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the final entry is zero, then stretch so the first entry is unchanged.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    alpha_bar = sqrt_alpha_bar**2
    # Undo the cumulative product: alpha_t = alpha_bar_t / alpha_bar_{t-1}, with alpha_0 = alpha_bar_0.
    per_step_alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - per_step_alphas


115
116
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, default `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction` (for flow-matching
            models; only supported when `predict_x0=True`).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise. The legacy values `midpoint`, `heun`, and `logrho` are accepted and mapped to `bh2`.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of UniPC for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            Indices of the steps at which the corrector is disabled, to mitigate the misalignment between
            `epsilon_theta(x_t, c)` and `epsilon_theta(x_t^c, c)`, which can influence convergence for a large
            guidance scale. The corrector is usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to 1.0):
            Shift applied to the flow sigmas when `use_flow_sigmas=True`, as
            `flow_shift * sigma / (1 + (flow_shift - 1) * sigma)`. It is overwritten with `exp(mu)` when `mu` is passed
            to `set_timesteps`.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied with
            `timestep_spacing="leading"`.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether dynamic shifting is enabled. Must be `True` in order to pass `mu` to `set_timesteps`.
        time_shift_type (`str`, defaults to `"exponential"`):
            The dynamic shifting type. Only `"exponential"` (which sets `flow_shift = exp(mu)`) is accepted when `mu`
            is passed to `set_timesteps`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        rescale_betas_zero_snr: bool = False,
        use_dynamic_shifting: bool = False,
        time_shift_type: str = "exponential",
    ):
        # Argument semantics are documented in the class docstring. The checks below read `self.config`, so this body
        # relies on `register_to_config` having recorded the arguments there first.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )

        # Build the training beta schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                # Legacy solver names are accepted and mapped to "bh2".
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        # Copy so instances never share (or mutate) the `[]` default argument or a caller-owned list.
        self.disable_corrector = list(disable_corrector) if disable_corrector is not None else []
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
279
280
281
282

    @property
    def step_index(self):
        """
M. Tolga Cangöz's avatar
M. Tolga Cangöz committed
283
        The index counter for current timestep. It will increase 1 after each scheduler step.
284
285
        """
        return self._step_index
286

287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Record the index of the first timestep to run. Pipelines should call this before inference starts.

        Args:
            begin_index (`int`, defaults to 0):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

305
306
307
    def set_timesteps(
        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                Dynamic-shifting parameter. When given, `config.flow_shift` is overwritten with `exp(mu)` before the
                schedule is built. Requires `use_dynamic_shifting=True` and `time_shift_type="exponential"`.

        Raises:
            ValueError: If `mu` is given without exponential dynamic shifting enabled, or if `timestep_spacing` or
                `final_sigmas_type` has an unsupported value.
        """
        if mu is not None:
            # Explicit check rather than `assert`, so the guard is not stripped under `python -O`.
            if not (self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"):
                raise ValueError(
                    "Passing `mu` requires `use_dynamic_shifting=True` and `time_shift_type='exponential'`, but got "
                    f"use_dynamic_shifting={self.config.use_dynamic_shifting} and "
                    f"time_shift_type={self.config.time_shift_type!r}."
                )
            self.config.flow_shift = np.exp(mu)

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Terminal sigma used for final_sigmas_type == "sigma_min". `None` means "the last computed inference sigma";
        # the default interpolated schedule instead uses the smallest training sigma.
        sigma_min_last = None
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_exponential_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_min_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5

        # Append the terminal sigma shared by every schedule type.
        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = sigmas[-1] if sigma_min_last is None else sigma_min_last
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        # Reset the multistep history so nothing from a previous run leaks into the new one.
        self.model_outputs = [None] * self.config.solver_order
        self.timestep_list = [None] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
431

432
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
433
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
434
435
436
437
438
439
440
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

Quentin Gallouédec's avatar
Quentin Gallouédec committed
441
        https://huggingface.co/papers/2205.11487
442
443
        """
        dtype = sample.dtype
444
        batch_size, channels, *remaining_dims = sample.shape
445
446
447
448
449

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
450
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
451
452
453
454
455
456
457
458
459
460

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

461
        sample = sample.reshape(batch_size, channels, *remaining_dims)
462
463
464
        sample = sample.to(dtype)

        return sample
465

466
467
468
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
469
        log_sigma = np.log(np.maximum(sigma, 1e-10))
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

490
491
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
492
493
494
495
496
497
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
        else:
            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
            sigma_t = sigma * alpha_t
498
499
500

        return alpha_t, sigma_t

501
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
502
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
503
504
        """Constructs the noise schedule of Karras et al. (2022)."""

Suraj Patil's avatar
Suraj Patil committed
505
506
507
508
509
510
511
512
513
514
515
516
517
518
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
519
520
521
522
523
524
525
526

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

546
        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
547
548
        return sigmas

549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

570
        sigmas = np.array(
571
572
573
574
575
576
577
578
579
580
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

581
    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Convert the raw model output into the quantity the UniPC update works with. This is a data (`x0`) prediction
        when `predict_x0` is enabled, and a noise (`epsilon`) prediction otherwise.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain. Deprecated and ignored; the scheduler's
                `step_index` selects the sigma instead.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.

        Raises:
            ValueError: If `sample` is missing, or `prediction_type` is unsupported for the selected mode.
        """
        # Legacy positional form: (model_output, timestep, sample).
        if args:
            timestep = args[0]
        else:
            timestep = kwargs.pop("timestep", None)
        if sample is None:
            if len(args) <= 1:
                raise ValueError("missing `sample` as a required keyword argument")
            sample = args[1]

        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        prediction_type = self.config.prediction_type

        if not self.predict_x0:
            # Noise-prediction mode: express the model output as epsilon.
            if prediction_type == "epsilon":
                return model_output
            if prediction_type == "sample":
                return (sample - alpha_t * model_output) / sigma_t
            if prediction_type == "v_prediction":
                return alpha_t * model_output + sigma_t * sample
            raise ValueError(
                f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the UniPCMultistepScheduler."
            )

        # Data-prediction mode: express the model output as x0.
        if prediction_type == "epsilon":
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif prediction_type == "sample":
            x0_pred = model_output
        elif prediction_type == "v_prediction":
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif prediction_type == "flow_prediction":
            # Uses the raw schedule sigma directly rather than the sigma_t from `_sigma_to_alpha_sigma_t`.
            x0_pred = sample - self.sigmas[self.step_index] * model_output
        else:
            raise ValueError(
                f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: Optional[torch.Tensor] = None,
        order: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep. It is only forwarded to
                `self.solver_p` when that is set. Otherwise the stored outputs in `self.model_outputs` are used.
            prev_timestep (`int`):
                Deprecated and ignored. Accepted as the first positional argument or as a keyword.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.

        Raises:
            `ValueError`: If `sample` or `order` is not provided.
            `NotImplementedError`: If `config.solver_type` is neither `"bh1"` nor `"bh2"`.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        # Delegate the whole predictor step to the external solver when one is configured.
        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        # `t` is the next point of the schedule (step_index + 1); `s0` is the current one.
        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # lambda = log(alpha) - log(sigma); h is the step size measured in lambda.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        # For each earlier point s_i: its relative lambda offset r_i and the scaled output difference D1_i.
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        # The sign of the step flips for data (x0) prediction.
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Row i - 1 of R is rks ** (i - 1) (a Vandermonde system); b[i - 1] = h_phi_k * i! / B(h).
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            # Stacked on dim 1: (batch, K, ...), matching the "k,bkc..." einsum below.
            D1s = torch.stack(D1s, dim=1)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        # Higher-order residual; it is identical for both prediction modes.
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: Optional[torch.Tensor] = None,
        this_sample: Optional[torch.Tensor] = None,
        order: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version).

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                Deprecated and ignored. Accepted as the first positional argument or as a keyword.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`. Only its device is used here.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.

        Raises:
            `ValueError`: If `last_sample`, `this_sample` or `order` is not provided.
            `NotImplementedError`: If `config.solver_type` is neither `"bh1"` nor `"bh2"`.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError("missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError("missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        model_t = this_model_output

        # `t` is the current point (step_index) that the predictor just reached; `s0` is the previous one.
        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # lambda = log(alpha) - log(sigma); h is the step size measured in lambda.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        # Earlier points start one index further back than in the predictor, because s0 is step_index - 1.
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        # The sign of the step flips for data (x0) prediction.
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Row i - 1 of R is rks ** (i - 1) (a Vandermonde system); b[i - 1] = h_phi_k * i! / B(h).
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        # The last coefficient weights the new output `model_t`; the others weight the history differences.
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = model_t - m0

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Find the position of `timestep` in a timestep schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep to look up.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The schedule to search. Defaults to `self.timesteps`.

        Returns:
            `int`:
                The index of `timestep` in the schedule. If the timestep is absent, the last index of
                `self.timesteps` is returned. If it occurs more than once, the second occurrence is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        A begin index that has already been set takes priority. Otherwise `timestep` is looked up in
        `self.timesteps` via `index_for_timestep`.
        """
        if self.begin_index is not None:
            self._step_index = self._begin_index
            return

        # Compare on the same device as the schedule.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep UniPC.

        Each call does four things in order:
        1. Converts `model_output`.
        2. Applies the UniC corrector to `sample` when a previous predictor step is available and not disabled.
        3. Records the converted output and `timestep` in the history buffers.
        4. Runs the UniP predictor to produce the next sample, then advances `self._step_index` by one.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: If `set_timesteps` has not been called (`self.num_inference_steps` is `None`).
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # The corrector refines `sample` using the sample the previous predictor started from (`self.last_sample`).
        # It is skipped on the first step, when no previous sample is stored, or when the previous step index is
        # listed in `self.disable_corrector`.
        use_corrector = (
            self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Shift the history buffers left by one and store the newest converted output / timestep at the end.
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        # Optionally cap the order by the number of remaining steps near the end of the schedule.
        if self.config.lower_order_final:
            this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        This scheduler applies no scaling. The input is returned unchanged, and any extra arguments are ignored.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to `original_samples` at the noise levels of `timesteps`.

        Args:
            original_samples (`torch.Tensor`):
                The clean samples.
            noise (`torch.Tensor`):
                The noise to mix in.
            timesteps (`torch.IntTensor`):
                Timesteps looked up in `self.timesteps`. Presumably one entry per sample along the first dimension
                of `original_samples`; confirm against callers.

        Returns:
            `torch.Tensor`:
                `alpha_t * original_samples + sigma_t * noise`, where `alpha_t, sigma_t` come from the schedule sigma
                at each selected index.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Append trailing singleton dims so sigma broadcasts against original_samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps this scheduler was configured with."""
        configured = self.config
        return configured.num_train_timesteps