# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


if is_scipy_available():
    import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D `torch.float32` tensor of length `num_diffusion_timesteps` holding the betas
            used by the scheduler to step the model outputs. Each entry is capped at `max_beta`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped so that 1 - beta never reaches 0.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler with Heun steps for discrete beta schedules.

    A single Heun step spans two consecutive calls to `step`:
    - The first call takes an Euler (first-order) step and caches its derivative, step size and input sample.
    - The second call averages the cached derivative with the one evaluated at the next noise level.
    - It then re-applies the update starting from the cached sample.

    For this reason `set_timesteps` duplicates every timestep after the first one.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `exp`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the original, denoised sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. Requires scipy.
            At most one of `use_karras_sigmas`, `use_exponential_sigmas` and `use_beta_sigmas` may be `True`.
        clip_sample (`bool`, defaults to `False`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        clip_sample: Optional[bool] = False,
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # This body reads `self.config`, relying on `@register_to_config` having populated it before it runs.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
        elif beta_schedule == "exp":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        #  set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the index of `timestep` within `schedule_timesteps` (defaults to `self.timesteps`).

        When the value occurs more than once, the second occurrence is returned. If there is a single match, that
        match is returned. An `IndexError` is raised when there is no match.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    @property
    def init_noise_sigma(self):
        """
        Standard deviation of the initial noise distribution.

        This is the largest sigma for `linspace`/`trailing` spacing, otherwise `sqrt(sigma_max**2 + 1)`.
        """
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        The sample is divided by `sqrt(sigma**2 + 1)`, where `sigma` is the noise level at the current step index.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain. Used only to initialise the step index on the first call.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Every timestep after the first is duplicated, because each Heun step needs two model evaluations. Every
        interior sigma is duplicated for the same reason. A final sigma of 0 is appended. The method also resets the
        Heun state and the step/begin indices.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, *optional*):
                The number of diffusion steps used when training the model. If `None`, the default
                `num_train_timesteps` attribute is used.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
                generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
                must be `None`, and `timestep_spacing` attribute will be ignored.

        Raises:
            ValueError: if both or neither of `num_inference_steps` and `timesteps` are given, if custom `timesteps`
                are combined with Karras/exponential/beta sigmas, or if `timestep_spacing` is unsupported.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
        if timesteps is not None and self.config.use_beta_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

        num_inference_steps = num_inference_steps or len(timesteps)
        self.num_inference_steps = num_inference_steps
        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        if timesteps is not None:
            timesteps = np.array(timesteps, dtype=np.float32)
        else:
            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
            elif self.config.timestep_spacing == "leading":
                step_ratio = num_train_timesteps // self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = num_train_timesteps / self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        # Training sigmas: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), then interpolated at the chosen timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optional re-spacing of the sigmas; the timesteps are then recovered from the new sigmas.
        # All three helpers return numpy arrays so the pipeline below stays in numpy.
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        # Duplicate interior sigmas: each Heun step reads sigma[i] and sigma[i + 1] across two `step` calls.
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        timesteps = torch.from_numpy(timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

        self.timesteps = timesteps.to(device=device)

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Map `sigma` to a (fractional) timestep.

        The mapping interpolates linearly in log-sigma space between the two neighbouring entries of `log_sigmas`.
        It expects `log_sigmas` to be sorted in ascending order.
        """
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: np.ndarray, num_inference_steps) -> np.ndarray:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: np.ndarray, num_inference_steps: int) -> np.ndarray:
        """Constructs an exponential noise schedule (log-linear from sigma_max down to sigma_min) as a numpy array."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Return numpy (like `_convert_to_karras`): callers feed the result into numpy-only code
        # (`_sigma_to_t`, `np.concatenate`).
        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: np.ndarray, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> np.ndarray:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Return numpy (like `_convert_to_karras`) so the downstream numpy pipeline gets a consistent type.
        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    @property
    def state_in_first_order(self):
        """
        Whether the next `step` call is the first (Euler) half of a Heun step.

        `dt` is set by the first half and cleared by the second half.
        """
        return self.dt is None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` if set, otherwise by looking `timestep` up in `self.timesteps`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: Union[torch.Tensor, np.ndarray],
        timestep: Union[float, torch.Tensor],
        sample: Union[torch.Tensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Calls alternate between the first-order (Euler) half and the second-order (Heun) half of a step:
        - The first half caches the derivative, `dt` and the input sample.
        - The second half averages the two derivatives and applies them to the cached sample, then clears the cache.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: if `config.prediction_type` is not `epsilon`, `v_prediction` or `sample`.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_next
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_next
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `v_prediction`, or `sample`"
            )

        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 2. 2nd order / Heun's method
            derivative = (sample - pred_original_sample) / sigma_next
            derivative = (self.prev_derivative + derivative) / 2

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Return `original_samples + noise * sigma`, with sigma looked up per timestep and broadcast over samples."""
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps