scheduling_ddim.py 24.4 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
Patrick von Platen's avatar
Patrick von Platen committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
16
17

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

Patrick von Platen's avatar
Patrick von Platen committed
18
import math
19
from dataclasses import dataclass
20
from typing import List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
21

Patrick von Platen's avatar
Patrick von Platen committed
22
import numpy as np
23
import torch
Patrick von Platen's avatar
Patrick von Platen committed
24

25
from ..configuration_utils import ConfigMixin, register_to_config
26
from ..utils import BaseOutput, randn_tensor
Kashif Rasul's avatar
Kashif Rasul committed
27
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
28
29
30


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output (returned by `DDIMScheduler.step` when
    `return_dict=True`).

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    # Optional so the output can be built with only `prev_sample`.
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.FloatTensor`): the betas used by the scheduler to step the model outputs, a 1-D
        `torch.float32` tensor of length `num_diffusion_timesteps`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            # Squared-cosine curve with a small 0.008 offset in t.
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), capped at max_beta (the cosine curve reaches ~0 at t=1,
        # which would otherwise drive the last beta to 1).
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


def rescale_zero_terminal_snr(betas):
    """
    Rescales betas so that the final timestep has zero signal-to-noise ratio, following Algorithm 1 of
    https://arxiv.org/pdf/2305.08891.pdf.

    The square root of the cumulative alpha product is shifted so its last value becomes exactly zero, then scaled so
    its first value is unchanged; the result is converted back into per-step betas. The input tensor is not modified.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # sqrt(alpha_bar_t) for every timestep.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original curve.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the curve ends at zero, then stretch so it still starts at `first`.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the square root, then undo the cumulative product to get per-step alphas.
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion. Only applied when `timestep_spacing="leading"`.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, default `"leading"`):
            The way the timesteps should be scaled: one of `"linspace"`, `"leading"` or `"trailing"`. Refer to Table
            2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for
            more information.
        rescale_betas_zero_snr (`bool`, default `False`):
            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
            medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        # Explicit `trained_betas` take precedence over any named schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        # Until `set_timesteps` is called, the schedule is every training timestep in descending order.
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. DDIM needs no scaling, so the sample is returned unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep (unused)

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """
        Return the DDIM variance term (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev), where a_* are cumulative alpha
        products. A negative `prev_timestep` (past the start of the chain) uses `self.final_alpha_cumprod`.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487

        Args:
            sample (`torch.FloatTensor`): predicted x_0 of shape `(batch_size, channels, *spatial_dims)`; any number
                (including zero) of trailing dimensions is accepted.

        Returns:
            `torch.FloatTensor`: the thresholded sample, same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each item of the batch.
        # math.prod([]) == 1, so inputs with no trailing dims also work.
        sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which `self.timesteps` is moved.

        Raises:
            ValueError: if `num_inference_steps` exceeds `num_train_timesteps`, or if `config.timestep_spacing` is
                not one of `"linspace"`, `"leading"` or `"trailing"`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called, if `config.prediction_type` is unknown, or if both
                `generator` and `variance_noise` are given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        # NOTE(review): this assumes a uniform stride of num_train_timesteps // num_inference_steps. That matches
        # "leading" spacing, but "linspace" and "trailing" spacing (see set_timesteps) round a float stride, so the
        # result may differ slightly from the next entry of self.timesteps.
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # Stochastic DDIM: only add noise when eta > 0 (eta == 0 is the deterministic sampler).
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Forward-diffuse `original_samples` to `timesteps`: sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise, where a_t is the
        cumulative alpha product. Per-timestep coefficients are broadcast over the trailing dims of the samples.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Compute the v-prediction target sqrt(a_t) * noise - sqrt(1 - a_t) * sample, where a_t is the cumulative alpha
        product. Per-timestep coefficients are broadcast over the trailing dims of `sample`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Number of training timesteps (`config.num_train_timesteps`)."""
        return self.config.num_train_timesteps