# Copyright 2025 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
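
    Example:

        A quick shape-check sketch (the exact beta values depend on `alpha_transform_type`):

        >>> betas = betas_for_alpha_bar(1000)
        >>> betas.shape
        torch.Size([1000])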
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    `PNDMScheduler` uses pseudo numerical methods for diffusion models such as the Runge-Kutta and linear multi-step
    method.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to `1000`):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to `0.0001`):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to `0.02`):
            The final `beta` value.
        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        skip_prk_steps (`bool`, defaults to `False`):
            Allows the scheduler to skip the Runge-Kutta steps defined in the original paper as being required before
            PLMS steps.
        set_alpha_to_one (`bool`, defaults to `False`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
            or `v_prediction` (see section 2.4 of the [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more information.
        steps_offset (`int`, defaults to `0`):
            An offset added to the inference steps, as required by some model families.
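
    Example:

        A minimal denoising-loop sketch; the `model` callable below is a stand-in for a trained
        noise-prediction network and is not part of this module:

        >>> import torch
        >>> from diffusers import PNDMScheduler

        >>> scheduler = PNDMScheduler(skip_prk_steps=True)
        >>> scheduler.set_timesteps(num_inference_steps=50)

        >>> sample = torch.randn(1, 3, 32, 32)
        >>> model = lambda x, t: torch.randn_like(x)  # hypothetical epsilon predictor
        >>> for t in scheduler.timesteps:
        ...     residual = model(sample, t)
        ...     sample = scheduler.step(residual, t, sample).prev_sample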
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        skip_prk_steps: bool = False,
        set_alpha_to_one: bool = False,
        prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
        timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://huggingface.co/papers/2202.09778
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.cur_model_output = 0
        self.counter = 0
        self.cur_sample = None
        self.ets = []

        # setable values
        self.num_inference_steps = None
        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.prk_timesteps = None
        self.plms_timesteps = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
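
        Example:

            A sketch of the resulting schedule for the default config (`num_train_timesteps=1000`,
            `timestep_spacing="leading"`) with `skip_prk_steps=True`; the second entry is repeated
            so the PLMS warm-up has enough history:

            >>> scheduler = PNDMScheduler(skip_prk_steps=True)
            >>> scheduler.set_timesteps(10)
            >>> scheduler.timesteps  # tensor([900, 800, 800, 700, ..., 100, 0])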
        """
        self.num_inference_steps = num_inference_steps
        # "linspace", "leading", "trailing" correspond to the annotations of Table 2 of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            self._timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
            self._timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(
                np.int64
            )
            self._timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            self.prk_timesteps = np.array([])
            self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
                ::-1
            ].copy()
        else:
            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
                np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
            )
            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
            self.plms_timesteps = self._timesteps[:-3][
                ::-1
            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy

        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

        self.ets = []
        self.counter = 0
        self.cur_model_output = 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise), and calls [`~PNDMScheduler.step_prk`]
        or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
        else:
            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)

    def step_prk(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the Runge-Kutta method. It performs four forward passes to approximate the solution to the differential
        equation.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
        prev_timestep = timestep - diff_to_prev
        timestep = self.prk_timesteps[self.counter // 4 * 4]

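        # One pseudo Runge-Kutta step spans four model evaluations; the branches below
        # accumulate them with the classical RK4 weights (1/6, 1/3, 1/3, 1/6).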
        if self.counter % 4 == 0:
            self.cur_model_output += 1 / 6 * model_output
            self.ets.append(model_output)
            self.cur_sample = sample
        elif (self.counter - 1) % 4 == 0:
            self.cur_model_output += 1 / 3 * model_output
        elif (self.counter - 2) % 4 == 0:
            self.cur_model_output += 1 / 3 * model_output
        elif (self.counter - 3) % 4 == 0:
            model_output = self.cur_model_output + 1 / 6 * model_output
            self.cur_model_output = 0

        # cur_sample should not be `None`
        cur_sample = self.cur_sample if self.cur_sample is not None else sample

        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_plms(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the linear multistep method. It performs one forward pass multiple times to approximate the solution.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(self.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations. "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )

        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        if self.counter != 1:
            self.ets = self.ets[-3:]
            self.ets.append(model_output)
        else:
            prev_timestep = timestep
            timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

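        # Linear multistep update: combine the stored model outputs with progressively
        # higher-order weights as more history accumulates; the four-term weights
        # 55/24, -59/24, 37/24, -9/24 are the classical 4th-order Adams-Bashforth coefficients.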
        if len(self.ets) == 1 and self.counter == 0:
            model_output = model_output
            self.cur_sample = sample
        elif len(self.ets) == 1 and self.counter == 1:
            model_output = (model_output + self.ets[-1]) / 2
            sample = self.cur_sample
            self.cur_sample = None
        elif len(self.ets) == 2:
            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    def _get_prev_sample(
        self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
        paper](https://huggingface.co/papers/2202.09778).

        Args:
            sample (`torch.Tensor`):
                The current sample x_t.
            timestep (`int`):
                The current timestep t.
            prev_timestep (`int`):
                The previous timestep (t-δ).
            model_output (`torch.Tensor`):
                The model output e_θ(x_t, t).

        Returns:
            `torch.Tensor`:
                The previous sample x_(t-δ).
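
        Writing α for the cumulative alpha product (`alpha_prod` in the code below) and β = 1 - α,
        the computed update is

            x_(t-δ) = (α_(t-δ) / α_t)^0.5 * x_t
                      - (α_(t-δ) - α_t) * e_θ(x_t, t) / (α_t * β_(t-δ)^0.5 + (α_t * β_t * α_(t-δ))^0.5)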
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        if self.config.prediction_type == "v_prediction":
            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        elif self.config.prediction_type != "epsilon":
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
            )

        # corresponds to (α_(t−δ) - α_t) divided by
        # denominator of x_t in formula (9) and plus 1
        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
        # sqrt(α_(t−δ)) / sqrt(α_t)
        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

        # corresponds to denominator of e_θ(x_t, t) in formula (9)
        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
        ) ** (0.5)

        # full formula (9)
        prev_sample = (
            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
        )

        return prev_sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
        diffusion process).

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps indicating the noise level for each sample.

        Returns:
            `torch.Tensor`:
                The noisy samples.
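
        Example:

            A forward-diffusion sketch on dummy data (shapes and timesteps are illustrative):

            >>> scheduler = PNDMScheduler()
            >>> clean = torch.randn(2, 3, 8, 8)
            >>> noise = torch.randn_like(clean)
            >>> noisy = scheduler.add_noise(clean, noise, torch.tensor([10, 500]))
            >>> noisy.shape
            torch.Size([2, 3, 8, 8])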
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self) -> int:
        return self.config.num_train_timesteps