scheduling_ddim.py 27.3 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
Patrick von Platen's avatar
Patrick von Platen committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
16
17

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

Patrick von Platen's avatar
Patrick von Platen committed
18
import math
19
from dataclasses import dataclass
20
from typing import List, Literal, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
21

Patrick von Platen's avatar
Patrick von Platen committed
22
import numpy as np
23
import torch
Patrick von Platen's avatar
Patrick von Platen committed
24

25
from ..configuration_utils import ConfigMixin, register_to_config
Dhruv Nair's avatar
Dhruv Nair committed
26
27
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
Kashif Rasul's avatar
Kashif Rasul committed
28
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
29
30
31


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Container for the values produced by the scheduler's `step` function.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` for the previous timestep. Pass it back to the model as the input of the next
            iteration of the denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The current estimate of the fully denoised sample `(x_{0})`, derived from the model output at this
            timestep. Handy for previewing progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
48
49


50
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Discretize a continuous cumulative-alpha curve into a sequence of per-step betas.

    The curve `alpha_bar(t)`, defined for t in [0, 1], gives the cumulative product of (1 - beta) up to time `t`.
    Step `i` of `N` covers the interval [i / N, (i + 1) / N], and its beta is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            How many betas to generate (`N`).
        max_beta (`float`, defaults to `0.999`):
            Upper bound applied to every beta; keep it below 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            Which `alpha_bar` curve to discretize: a squared cosine with a 0.008 offset, or `exp(-12 t)`.

    Returns:
        `torch.Tensor`:
            A float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    n_steps = num_diffusion_timesteps
    # Ratio of consecutive alpha_bar values gives the per-step (1 - beta); cap each beta at max_beta.
    per_step_betas = [
        min(1 - alpha_bar_fn((step + 1) / n_steps) / alpha_bar_fn(step / n_steps), max_beta)
        for step in range(n_steps)
    ]
    return torch.tensor(per_step_betas, dtype=torch.float32)
Patrick von Platen's avatar
Patrick von Platen committed
94
95


96
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is zero, following Algorithm 1 of
    https://huggingface.co/papers/2305.08891.

    The square roots of the cumulative alphas are shifted so the last one becomes zero. They are then scaled so the
    first one keeps its original value, and finally converted back into per-step betas.

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: The rescaled betas, whose cumulative alpha product ends at zero.
    """
    sqrt_cumprod = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first_value = sqrt_cumprod[0].clone()
    last_value = sqrt_cumprod[-1].clone()

    # Move the final entry to zero, then stretch so the first entry is restored to its old value.
    sqrt_cumprod = (sqrt_cumprod - last_value) * (first_value / (first_value - last_value))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    cumprod = sqrt_cumprod**2
    step_alphas = torch.cat([cumprod[:1], cumprod[1:] / cumprod[:-1]])

    return 1 - step_alphas


Patrick von Platen's avatar
Patrick von Platen committed
131
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of the training noise schedule.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value of the training noise schedule.
        beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
            of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`. `"squaredcos_cap_v2"` ignores `beta_start`
            and `beta_end`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` and
            `beta_schedule`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alphas product value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps when `timestep_spacing="leading"`, as required by some model
            families.
        prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
            Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
            process), `"sample"` (directly predicts the denoised sample `x_0`), or `"v_prediction"` (see section 2.4
            of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion. Takes precedence over `clip_sample` when both are enabled.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
            The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
            Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        # Explicit betas win over any named schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values; `num_inference_steps` stays None until `set_timesteps` is called (checked in `step`).
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. DDIM needs no scaling, so the sample is returned unchanged.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain (unused).

        Returns:
            `torch.Tensor`:
                The input sample, unchanged.
        """
        return sample

    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
        """
        Computes the variance of the noise added at a given diffusion step.

        For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
        literature:
            var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        where `alpha_prod` is the cumulative product of the alphas and `beta_prod = 1 - alpha_prod`.

        Args:
            timestep (`int`):
                The current timestep in the diffusion process.
            prev_timestep (`int`):
                The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.

        Returns:
            `torch.Tensor`:
                The variance for the current timestep.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample. Unlike that version, the
    # flattened size is computed with `math.prod`, so samples without spatial dims (`(batch, channels)`) also work.
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Apply dynamic thresholding to the predicted sample.

        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded. Must have at least two dims `(batch_size, channels, ...)`.

        Returns:
            `torch.Tensor`:
                The thresholded sample, with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image.
        # `math.prod([])` is the int 1, whereas `np.prod([])` is the float 1.0, which `reshape` rejects.
        sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`Union[str, torch.device]`, *optional*):
                The device to use for the timesteps.

        Raises:
            ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`, or if
                `self.config.timestep_spacing` is not one of `"linspace"`, `"leading"` or `"trailing"`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator: Optional[torch.Generator] = None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`, *optional*, defaults to 0.0):
                The weight of noise for added noise in diffusion step. A value of 0 corresponds to DDIM (deterministic)
                and 1 corresponds to DDPM (fully stochastic).
            use_clipped_model_output (`bool`, *optional*, defaults to `False`):
                If `True`, re-derives the predicted noise from the predicted original sample after it has been clipped
                (`clip_sample`) or thresholded (`thresholding`). If the prediction was not altered, the re-derived
                noise coincides with the original and this flag has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator for reproducible sampling.
            variance_noise (`torch.Tensor`, *optional*):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
                tuple `(prev_sample, pred_original_sample)` is returned.

        Raises:
            ValueError: If `set_timesteps` has not been called, if `prediction_type` is unsupported, or if both
                `generator` and `variance_noise` are given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        # NOTE(review): this assumes a uniform gap of num_train_timesteps // num_inference_steps, which matches
        # "leading" spacing; "linspace" and "trailing" can produce other gaps — confirm this is intended.
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # Only the stochastic (eta > 0) variant adds fresh noise.
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
        diffusion process).

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps indicating the noise level for each sample.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """
        Compute the velocity prediction from the sample and noise according to the velocity formula.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            noise (`torch.Tensor`):
                The noise tensor.
            timesteps (`torch.IntTensor`):
                The timesteps for velocity computation.

        Returns:
            `torch.Tensor`:
                The computed velocity.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self) -> int:
        """Return the number of training timesteps the scheduler was configured with."""
        return self.config.num_train_timesteps