# Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import math
16
from dataclasses import dataclass
17
from typing import List, Literal, Optional, Tuple, Union
18
19
20
21
22

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
23
24
from ..utils import BaseOutput, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
25
26


27
28
29
30
if is_scipy_available():
    import scipy.stats


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Return type of `HeunDiscreteScheduler.step` when `return_dict=True`. With `return_dict=False`, `step` returns the
# same two tensors as a plain `(prev_sample, pred_original_sample)` tuple instead.
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
class HeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


50
# Module-level helper used by `HeunDiscreteScheduler.__init__` for the `squaredcos_cap_v2` (cosine) and `exp` beta
# schedules. Raises `ValueError` for any other `alpha_transform_type`. Each beta is clamped to `max_beta` because the
# cosine alpha_bar reaches 0 at t = 1, which would otherwise drive the last beta to 1.
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


96
97
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler with Heun steps for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Heun's method is a predictor-corrector integrator (`order = 2`). Each move from one noise level to the next is an
    Euler (first-order) call to `step`, followed by a corrector call that re-evaluates the model at the new noise level
    and averages the two derivatives. The final move to `sigma = 0` is left as a plain Euler step. `set_timesteps`
    duplicates the interior timesteps and sigmas so that a denoising loop over `self.timesteps` issues these calls in
    order.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `exp`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample `x_0`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. Requires scipy.
            At most one of `use_karras_sigmas`, `use_exponential_sigmas` and `use_beta_sigmas` may be enabled, and
            none of them can be combined with custom `timesteps` in `set_timesteps`.
        clip_sample (`bool`, defaults to `False`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied when
            `timestep_spacing="leading"`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        clip_sample: Optional[bool] = False,
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
        elif beta_schedule == "exp":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Build a default (training-resolution) schedule so `sigmas`/`timesteps` exist before the pipeline calls
        # `set_timesteps` itself.
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Looks up the position of `timestep` in the (duplicated) schedule; the inline comment explains which match wins.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    @property
    def init_noise_sigma(self):
        """
        Standard deviation of the initial noise distribution: `max(sigmas)` for `linspace`/`trailing` spacing,
        otherwise `sqrt(max(sigmas) ** 2 + 1)`.
        """
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """
        Scales the denoising model input by `1 / sqrt(sigma**2 + 1)`, where `sigma` is the noise level of the current
        step. On the first call after `set_timesteps`, the step index is initialised from `timestep`.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        num_train_timesteps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Because Heun's method evaluates the model twice per interval, every timestep after the first is repeated twice
        in `self.timesteps`. Likewise, every sigma strictly between the first one and the appended final `0.0` is
        repeated twice in `self.sigmas`. Any in-progress predictor/corrector state and the step/begin indices are
        reset.

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The
                sigmas are always kept on the CPU.
            num_train_timesteps (`int`, *optional*):
                The number of diffusion steps used when training the model. If `None`, the default
                `num_train_timesteps` attribute is used.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
                generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
                must be `None`, and `timestep_spacing` attribute will be ignored.

        Raises:
            ValueError: If both or neither of `num_inference_steps` and `timesteps` are given, if custom `timesteps`
                are combined with Karras/exponential/beta sigmas, or if `timestep_spacing` is unsupported.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
        if timesteps is not None and self.config.use_beta_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

        num_inference_steps = num_inference_steps or len(timesteps)
        self.num_inference_steps = num_inference_steps
        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        if timesteps is not None:
            timesteps = np.array(timesteps, dtype=np.float32)
        else:
            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
            elif self.config.timestep_spacing == "leading":
                step_ratio = num_train_timesteps // self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = num_train_timesteps / self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        # Training-resolution sigmas, then interpolated at the (possibly fractional) inference timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Alternative sigma spacings: the timesteps are recomputed from the new sigmas so both stay consistent.
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        # Duplicate every interior sigma (predictor + corrector); keep the first sigma and the final 0 single.
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        timesteps = torch.from_numpy(timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

        self.timesteps = timesteps.to(device=device, dtype=torch.float32)

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None
        # Also drop the sample cached by an unfinished predictor step, so it is not kept alive across schedules.
        self.sample = None

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Inverts the sigma schedule: maps a sigma to a fractional training timestep by linear interpolation in log-sigma.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # In this scheduler the three `_convert_to_*` helpers receive NumPy arrays from `set_timesteps` and return NumPy
    # arrays, despite their `torch.Tensor` annotations.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Requires scipy: `__init__` raises ImportError when `use_beta_sigmas=True` and scipy is not installed.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    @property
    def state_in_first_order(self):
        """
        `True` when the next `step` call is a first-order (Euler predictor) step, i.e. no `dt` is pending from a
        previous predictor step; `False` when the next call is the Heun corrector.
        """
        return self.dt is None

    # Called lazily by `scale_model_input` and `step` on the first call after `set_timesteps`.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: Union[torch.Tensor, np.ndarray],
        timestep: Union[float, torch.Tensor],
        sample: Union[torch.Tensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Calls alternate between a first-order (Euler predictor) step and a second-order (Heun corrector) step; see
        `state_in_first_order`. The predictor stores its derivative, step size and input sample. The corrector averages
        its derivative with the stored one and re-applies the step from the stored sample.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model, interpreted according to `prediction_type`.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain. Only used to initialise the step index.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
                returned, otherwise a `(prev_sample, pred_original_sample)` tuple is returned.

        Raises:
            ValueError: If `prediction_type` is not one of `epsilon`, `v_prediction` or `sample`.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # Snapshot the phase once; `self.dt` is only mutated near the end of this call.
        first_order = self.state_in_first_order

        if first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # Same sigma that `scale_model_input` used for this call: sigmas[step_index].
        sigma_input = sigma_hat if first_order else sigma_next

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `v_prediction`, or `sample`"
            )

        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        if first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 2. 2nd order / Heun's method: average the predictor and corrector slopes
            derivative = (sample - pred_original_sample) / sigma_next
            derivative = (self.prev_derivative + derivative) / 2

            # 3. take prev timestep & sample: the corrected step restarts from the predictor's input sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Forward-noises `original_samples` as `original_samples + noise * sigma` at the looked-up step indices.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        """Returns the configured number of training timesteps."""
        return self.config.num_train_timesteps