# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput
class DDIMParallelSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.

    Raises:
        ValueError: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at `max_beta`
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that the terminal signal-to-noise ratio is zero, following Algorithm 1 of
    https://huggingface.co/papers/2305.08891.

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`:
            Rescaled betas with zero terminal SNR.
    """
    # sqrt of the cumulative alpha product for the incoming schedule
    sqrt_cumprod = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # remember the end points before transforming
    first = sqrt_cumprod[0].clone()
    last = sqrt_cumprod[-1].clone()

    # shift so the last entry becomes zero, then scale so the first entry is restored
    shifted = sqrt_cumprod - last
    rescaled = shifted * (first / (first - last))

    # undo the sqrt and the cumulative product to get per-step alphas back
    cumprod = rescaled**2
    step_alphas = torch.cat([cumprod[0:1], cumprod[1:] / cumprod[:-1]])

    return 1 - step_alphas


class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://huggingface.co/papers/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen,
            https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
            diffusion models (such as stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, default `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, default `False`):
            whether to rescale the betas to have zero terminal SNR (proposed by
            https://huggingface.co/papers/2305.08891). This can enable the model to generate very bright and dark
            samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1
    _is_ode_scheduler = True

    @register_to_config
    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        """
        Build the beta / alpha tables and the default (full-length, descending) timestep schedule.

        `trained_betas`, when given, takes precedence over `beta_schedule`, `beta_start` and `beta_end`.

        Raises:
            NotImplementedError: If `beta_schedule` is not one of `linear`, `scaled_linear` or `squaredcos_cap_v2`
                (and `trained_betas` is `None`).
        """
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        This scheduler applies no scaling: `sample` is returned unchanged and `timestep` is ignored.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    def _get_variance(self, timestep, prev_timestep=None):
        if prev_timestep is None:
            prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def _batch_get_variance(self, t, prev_t):
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
        alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
287
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
288
        """
289
290
        Apply dynamic thresholding to the predicted sample.

291
292
293
294
295
296
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

Quentin Gallouédec's avatar
Quentin Gallouédec committed
297
        https://huggingface.co/papers/2205.11487
298
299
300
301
302
303
304
305

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded.

        Returns:
            `torch.Tensor`:
                The thresholded sample.
306
307
        """
        dtype = sample.dtype
308
        batch_size, channels, *remaining_dims = sample.shape
309
310
311
312
313

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
314
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
315
316
317
318
319
320
321
322
323
324

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

325
        sample = sample.reshape(batch_size, channels, *remaining_dims)
326
327
328
329
330
331
332
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Stores `num_inference_steps` on the scheduler and replaces `self.timesteps` with a descending `int64`
        tensor built according to `self.config.timestep_spacing`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`Union[str, torch.device]`, *optional*):
                The device to use for the timesteps.

        Raises:
            ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`, or if
                `self.config.timestep_spacing` is not one of `"linspace"`, `"leading"` or `"trailing"`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                f" with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
383
        model_output: torch.Tensor,
384
        timestep: int,
385
        sample: torch.Tensor,
386
387
388
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
389
        variance_noise: Optional[torch.Tensor] = None,
390
391
392
393
394
395
396
        return_dict: bool = True,
    ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
397
            model_output (`torch.Tensor`): direct output from learned diffusion model.
398
            timestep (`int`): current discrete timestep in the diffusion chain.
399
            sample (`torch.Tensor`):
400
401
402
403
404
405
406
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have not effect.
            generator: random number generator.
407
            variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
408
                can directly provide the noise for the variance itself. This is useful for methods such as
Quentin Gallouédec's avatar
Quentin Gallouédec committed
409
                CycleDiffusion. (https://huggingface.co/papers/2210.05559)
410
411
412
413
414
415
416
417
418
419
420
421
422
            return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

Quentin Gallouédec's avatar
Quentin Gallouédec committed
423
        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
Quentin Gallouédec's avatar
Quentin Gallouédec committed
444
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

Quentin Gallouédec's avatar
Quentin Gallouédec committed
477
        # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
478
479
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

Quentin Gallouédec's avatar
Quentin Gallouédec committed
480
        # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
499
500
501
502
            return (
                prev_sample,
                pred_original_sample,
            )
503
504
505
506
507

        return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def batch_step_no_noise(
        self,
508
        model_output: torch.Tensor,
509
        timesteps: List[int],
510
        sample: torch.Tensor,
511
512
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
513
    ) -> torch.Tensor:
514
515
516
517
518
519
520
521
522
        """
        Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
        Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
        is pre-sampled by the pipeline.

        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
523
            model_output (`torch.Tensor`): direct output from learned diffusion model.
524
525
            timesteps (`List[int]`):
                current discrete timesteps in the diffusion chain. This is now a list of integers.
526
            sample (`torch.Tensor`):
527
528
529
530
531
532
533
534
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have not effect.

        Returns:
535
            `torch.Tensor`: sample tensor at previous timestep.
536
537
538
539
540
541
542
543
544

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        assert eta == 0.0

Quentin Gallouédec's avatar
Quentin Gallouédec committed
545
        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        t = timesteps
        prev_t = t - self.config.num_train_timesteps // self.num_inference_steps

        t = t.view(-1, *([1] * (model_output.ndim - 1)))
        prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))

        # 1. compute alphas, betas
        self.alphas_cumprod = self.alphas_cumprod.to(model_output.device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(model_output.device)
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
        alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
Quentin Gallouédec's avatar
Quentin Gallouédec committed
573
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._batch_get_variance(t, prev_t).to(model_output.device).view(*alpha_prod_t_prev.shape)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        return prev_sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
        diffusion process).

        The result is `sqrt(alphas_cumprod[timesteps]) * original_samples + sqrt(1 - alphas_cumprod[timesteps]) *
        noise`. As a side effect, `self.alphas_cumprod` is moved to the device of `original_samples`.

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps indicating the noise level for each sample. They are used to index
                `self.alphas_cumprod`.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # Signal coefficient: flattened to 1-D, then padded with trailing singleton dimensions so that it
        # broadcasts against `original_samples` along the leading dimension.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        # Noise coefficient, reshaped for broadcasting in the same way.
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """
        Compute the velocity prediction from the sample and noise according to the velocity formula.

        The result is `sqrt(alphas_cumprod[timesteps]) * noise - sqrt(1 - alphas_cumprod[timesteps]) * sample`. As a
        side effect, `self.alphas_cumprod` is moved to the device of `sample`.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            noise (`torch.Tensor`):
                The noise tensor.
            timesteps (`torch.IntTensor`):
                The timesteps for velocity computation. They are used to index `self.alphas_cumprod`.

        Returns:
            `torch.Tensor`:
                The computed velocity.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        # Coefficient applied to the noise: flattened to 1-D, then padded with trailing singleton dimensions so
        # that it broadcasts against `sample` along the leading dimension.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        # Coefficient applied to the sample, reshaped for broadcasting in the same way.
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Return the `num_train_timesteps` value stored in the scheduler config."""
        train_steps = self.config.num_train_timesteps
        return train_steps