# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_scipy_available, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


# scipy is optional: `scipy.stats` is used by `_convert_to_beta`, and `__init__` raises an ImportError when
# `use_beta_sigmas=True` is requested without it.
if is_scipy_available():
    import scipy.stats

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`:
            Rescaled betas with zero terminal SNR.
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of the beta schedule.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value of the beta schedule.
        beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`, *optional*):
            Prediction type of the scheduler function; can be `"epsilon"` (predicts the noise of the diffusion
            process), `"sample"` (directly predicts the noisy sample) or `"v_prediction"` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        interpolation_type (`Literal["linear", "log_linear"]`, defaults to `"linear"`, *optional*):
            The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            `"linear"` or `"log_linear"`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. Requires scipy.
            At most one of `use_karras_sigmas`, `use_exponential_sigmas` and `use_beta_sigmas` may be enabled.
        sigma_min (`float`, *optional*):
            The minimum sigma value for the noise schedule. If not provided, defaults to the last sigma in the
            schedule. Read by the Karras, exponential and beta sigma conversions.
        sigma_max (`float`, *optional*):
            The maximum sigma value for the noise schedule. If not provided, defaults to the first sigma in the
            schedule. Read by the Karras, exponential and beta sigma conversions.
        timestep_spacing (`Literal["linspace", "leading", "trailing"]`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        timestep_type (`Literal["discrete", "continuous"]`, defaults to `"discrete"`):
            The type of timesteps to use. Can be `"discrete"` or `"continuous"`. `"continuous"` only changes the
            timesteps when combined with `prediction_type="v_prediction"`, in which case they are `0.25 * log(sigma)`.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        final_sigmas_type (`Literal["zero", "sigma_min"]`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
        interpolation_type: Literal["linear", "log_linear"] = "linear",
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
        timestep_type: Literal["discrete", "continuous"] = "discrete",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
    ):
        # Validate the sigma-schedule flags. These checks read `self.config`, so they rely on `@register_to_config`
        # having populated it before this body runs.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )

        # Build the training beta schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), flipped so the highest-noise sigma comes first.
        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # setable values
        self.num_inference_steps = None

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if timestep_type == "continuous" and prediction_type == "v_prediction":
            # Vectorized `0.25 * log(sigma)` for every training sigma (no per-element Python loop).
            self.timesteps = 0.25 * sigmas.log()
        else:
            self.timesteps = timesteps

        # Append a terminal sigma of 0 after the training sigmas.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas
        self.use_exponential_sigmas = use_exponential_sigmas
        self.use_beta_sigmas = use_beta_sigmas

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self) -> Union[float, torch.Tensor]:
        """
        The standard deviation of the initial noise distribution.

        Returns:
            `float` or `torch.Tensor`:
                The largest sigma when `config.timestep_spacing` is `"linspace"` or `"trailing"`; otherwise
                `(max_sigma**2 + 1) ** 0.5`.
        """
        # `self.sigmas` is normally a tensor, but a plain Python list is handled as well.
        max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return max_sigma

        return (max_sigma**2 + 1) ** 0.5

    @property
    def step_index(self) -> Optional[int]:
        """
        The index counter for current timestep. It will increase by 1 after each scheduler step.

        Returns:
            `int` or `None`:
                The current step index, or `None` if not initialized.
        """
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.

        Returns:
            `int` or `None`:
                The begin index for the scheduler, or `None` if not set.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0) -> None:
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Divides the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.Tensor`):
                The input sample to be scaled.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain. Only used to initialize the step index when it has not
                been set yet.

        Returns:
            `torch.Tensor`:
                A scaled input sample, divided by `(sigma**2 + 1) ** 0.5`.
        """
        # `set_timesteps` resets the step index to None, so the first call of a run resolves it from `timestep`.
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)

        self.is_scale_input_called = True
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If `None`,
                `timesteps` or `sigmas` must be provided.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                must be `None`, and `timestep_spacing` attribute will be ignored.
            sigmas (`List[float]`, *optional*):
                Custom sigmas used to support arbitrary timesteps schedule. If `None`, timesteps and sigmas will be
                generated based on the relevant scheduler attributes. If `sigmas` is passed, `num_inference_steps` and
                `timesteps` must be `None`, and the timesteps will be generated based on the custom sigmas schedule.
                The last entry is treated as the final sigma and gets no timestep of its own.

        Raises:
            ValueError: If the combination of `num_inference_steps`, `timesteps` and `sigmas` is invalid, if custom
                `timesteps` are combined with a Karras/exponential/beta sigma schedule or with continuous v-prediction
                timesteps, or if `timestep_spacing`, `interpolation_type` or `final_sigmas_type` is unsupported.
        """
        if timesteps is not None and sigmas is not None:
            raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
        if num_inference_steps is None and timesteps is None and sigmas is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas`.")
        if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
        if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
        if timesteps is not None and self.config.use_beta_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
        if (
            timesteps is not None
            and self.config.timestep_type == "continuous"
            and self.config.prediction_type == "v_prediction"
        ):
            raise ValueError(
                "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
            )

        if num_inference_steps is None:
            # Custom sigmas include the trailing final sigma, hence the `- 1`.
            num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
        self.num_inference_steps = num_inference_steps

        if sigmas is not None:
            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
            sigmas = np.array(sigmas).astype(np.float32)
            # Map every sigma except the final one back onto the training timestep axis.
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])

        else:
            if timesteps is not None:
                timesteps = np.array(timesteps).astype(np.float32)
            else:
                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
                if self.config.timestep_spacing == "linspace":
                    timesteps = np.linspace(
                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
                    )[::-1].copy()
                elif self.config.timestep_spacing == "leading":
                    step_ratio = self.config.num_train_timesteps // num_inference_steps
                    # Integer multiples of `step_ratio`, in descending order, stored as float32.
                    timesteps = (
                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                    )
                    timesteps += self.config.steps_offset
                elif self.config.timestep_spacing == "trailing":
                    step_ratio = self.config.num_train_timesteps / num_inference_steps
                    # `step_ratio` may be fractional here; round to integer-valued timesteps stored as float32.
                    timesteps = (
                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                    )
                    timesteps -= 1
                else:
                    raise ValueError(
                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                    )

            # Training sigmas indexed by integer training timestep (ascending with the timestep).
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            if self.config.interpolation_type == "linear":
                # Evaluate the training sigmas at the (possibly fractional) inference timesteps.
                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            elif self.config.interpolation_type == "log_linear":
                # NOTE(review): this ignores `timesteps` and yields `num_inference_steps + 1` log-spaced sigmas, so
                # unless a Karras/exponential/beta schedule below replaces them, `self.sigmas` ends up with
                # `num_inference_steps + 2` entries instead of the usual `num_inference_steps + 1`. Confirm intended.
                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
            else:
                raise ValueError(
                    f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
                    " 'linear' or 'log_linear'"
                )

            # Optional schedule overrides; timesteps are re-derived from the new sigmas.
            if self.config.use_karras_sigmas:
                sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            elif self.config.use_exponential_sigmas:
                sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            elif self.config.use_beta_sigmas:
                sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )

            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
            # Vectorized `0.25 * log(sigma)` over all sigmas except the appended final one (which may be 0).
            self.timesteps = (0.25 * sigmas[:-1].log()).to(device=device)
        else:
            self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        self._step_index = None
        self._begin_index = None
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
        """
        Convert sigma values to corresponding timestep values through interpolation.

        `log(sigma)` is located between two neighbouring entries of `log_sigmas`, and their indices are linearly
        interpolated.

        Args:
            sigma (`np.ndarray` or `float`):
                The sigma value(s) to convert to timestep(s). Python scalars and sequences are accepted as well and
                converted with `np.asarray`.
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation, indexed by training timestep. Assumed to be
                sorted in ascending order.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s), with the same shape as `sigma`
                and clamped to `[0, len(log_sigmas) - 1]`.
        """
        # Accept plain Python scalars/sequences too: `sigma.shape` is needed for the final reshape.
        sigma = np.asarray(sigma)

        # get log sigma (clamped so a zero sigma does not produce -inf)
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution: differences between every schedule entry and every requested sigma
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range: cumsum + argmax yields the index of the last schedule entry <= log_sigma (0 if none),
        # clipped so that `high_idx` stays in bounds
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: Union[np.ndarray, torch.Tensor], num_inference_steps: int) -> np.ndarray:
        """
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`np.ndarray` or `torch.Tensor`):
                The input sigma values to be converted. `in_sigmas[0]` and `in_sigmas[-1]` are used as the defaults for
                `sigma_max` and `sigma_min` when those are not set in the config.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `np.ndarray`:
                `num_inference_steps` sigma values following the Karras noise schedule, from `sigma_max` down to
                `sigma_min`.
        """
        # `getattr` with a default keeps this working for schedulers that copy this method but have no
        # `sigma_min` / `sigma_max` config entries.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
    def _convert_to_exponential(
        self, in_sigmas: Union[np.ndarray, torch.Tensor], num_inference_steps: int
    ) -> np.ndarray:
        """
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`np.ndarray` or `torch.Tensor`):
                The input sigma values to be converted. `in_sigmas[0]` and `in_sigmas[-1]` are used as the defaults for
                `sigma_max` and `sigma_min` when those are not set in the config.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `np.ndarray`:
                `num_inference_steps` sigma values evenly spaced in log-space from `sigma_max` down to `sigma_min`.
        """
        # `getattr` with a default keeps this working for schedulers that copy this method but have no
        # `sigma_min` / `sigma_max` config entries.
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
        """
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        """
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

622
        sigmas = np.array(
623
624
625
626
627
628
629
630
631
632
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
    def index_for_timestep(
        self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Locate `timestep` within a timestep schedule and return its position.

        Args:
            timestep (`float` or `torch.Tensor`):
                Timestep value to look up; it is matched by exact equality against the schedule entries.
            schedule_timesteps (`torch.Tensor`, *optional*):
                Schedule to search. Falls back to `self.timesteps` when omitted.

        Returns:
            `int`:
                Position of `timestep` in the schedule. When the value occurs more than once, the position of the
                second occurrence is returned. This avoids skipping a sigma on the very first `step` when sampling
                starts part-way through the schedule (e.g. image-to-image). If the value does not occur at all,
                indexing the empty match set raises an `IndexError`.
        """
        haystack = self.timesteps if schedule_timesteps is None else schedule_timesteps
        matches = (haystack == timestep).nonzero()

        # Prefer the second match when duplicates exist; otherwise take the sole (first) match.
        chosen = 1 if len(matches) >= 2 else 0
        return matches[chosen].item()

663
664
665
666
667
668
669
670
    def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
        """
        Initialize the step index for the scheduler based on the given timestep.

        Args:
            timestep (`float` or `torch.Tensor`):
                The current timestep to initialize the step index from.
        """
671
672
673
674
675
676
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
YiYi Xu's avatar
YiYi Xu committed
677

hlky's avatar
hlky committed
678
679
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, *optional*, defaults to `0.0`):
                Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase
                randomness.
            s_tmin (`float`, *optional*, defaults to `0.0`):
                Minimum sigma threshold for applying stochasticity. Noise is only added when the current sigma is at
                least this value.
            s_tmax (`float`, *optional*, defaults to `inf`):
                Maximum sigma threshold for applying stochasticity. Noise is only added when the current sigma is at
                most this value.
            s_noise (`float`, *optional*, defaults to `1.0`):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator for reproducible sampling.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor and the second
                element is the predicted original sample.

        Raises:
            `ValueError`:
                If `timestep` is an integer index rather than a value from `scheduler.timesteps`, or if
                `config.prediction_type` is not one of `epsilon`, `sample` (or the legacy `original_sample`), or
                `v_prediction`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        # "Churn": when sigma lies in [s_tmin, s_tmax], temporarily raise the noise level from sigma to
        # sigma_hat = sigma * (1 + gamma) by injecting fresh noise. gamma is capped at sqrt(2) - 1.
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            # Add just enough variance to move the sample from noise level sigma to sigma_hat.
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            # NOTE(review): c_out / c_skip are evaluated at `sigma`, not `sigma_hat`; the two differ only when
            # churn is active (gamma > 0). Left as-is to preserve existing outputs — confirm this is intended.
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                " or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        # Explicit Euler step from sigma_hat to the next sigma in the schedule.
        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Noise `original_samples` to the sigma level associated with each entry of `timesteps`.

        The result is `original_samples + noise * sigma`, where `sigma` is taken from `self.sigmas` and broadcast over
        the trailing dimensions of the samples.

        Args:
            original_samples (`torch.Tensor`):
                Clean samples to be noised.
            noise (`torch.Tensor`):
                Noise to scale by sigma and add to `original_samples`.
            timesteps (`torch.Tensor`):
                One timestep per sample, used to select the sigma (noise level) from the schedule.

        Returns:
            `torch.Tensor`:
                The noised samples.
        """
        target_device = original_samples.device
        # Bring the sigma table onto the samples' device and dtype before indexing it.
        sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

        if target_device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
            timesteps = timesteps.to(target_device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(target_device)
            timesteps = timesteps.to(target_device)

        if self.begin_index is None:
            # Training, or a pipeline that never calls `set_begin_index`: look each timestep up individually.
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            # After the first denoising step (inpainting) reuse the current step index; before it (building the
            # initial img2img latent) use the begin index.
            shared_index = self.begin_index if self.step_index is None else self.step_index
            step_indices = [shared_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Append singleton dims so sigma broadcasts against samples of any rank.
        for _ in range(original_samples.ndim - sigma.ndim):
            sigma = sigma.unsqueeze(-1)

        return original_samples + noise * sigma

843
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Compute the velocity prediction for the given sample and noise at the specified timesteps.

        This method implements the velocity prediction used in v-prediction models, which predicts a linear combination
        of the sample and noise.

        Args:
            sample (`torch.Tensor`):
                The input sample for which to compute the velocity.
            noise (`torch.Tensor`):
                The noise tensor corresponding to the sample.
            timesteps (`torch.Tensor`):
                The timesteps at which to compute the velocity. They must be values from `self.timesteps` (not integer
                indices).

        Returns:
            `torch.Tensor`:
                The velocity prediction computed as `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample`, where
                `alpha_prod = 1 / (1 + sigma**2)` and `sigma` is the schedule sigma for each timestep.

        Raises:
            `ValueError`:
                If `timesteps` is an integer or an integer tensor.
        """
        if isinstance(timesteps, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if sample.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timesteps = timesteps.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timesteps = timesteps.to(sample.device)

        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        # `step_indices` are positions in `self.timesteps`; `add_noise` and `step` use those same positions to index
        # `self.sigmas`. `self.alphas_cumprod` is presumably ordered by training timestep rather than by schedule
        # position (confirm against `__init__`), so indexing it with `step_indices` can select the wrong noise level.
        # Derive alpha_prod from sigma instead: alpha_prod = 1 / (1 + sigma**2), the same relation the
        # `v_prediction` branch of `step` relies on (c_skip = 1 / (sigma**2 + 1)).
        sigmas = self.sigmas.to(sample)
        sigma = sigmas[step_indices].flatten()
        norm = (sigma**2 + 1) ** 0.5
        sqrt_alpha_prod = 1 / norm
        sqrt_one_minus_alpha_prod = sigma / norm

        # Append singleton dims so the per-sample coefficients broadcast against `sample` of any rank.
        for _ in range(sample.ndim - sqrt_alpha_prod.ndim):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

898
    def __len__(self) -> int:
hlky's avatar
hlky committed
899
        return self.config.num_train_timesteps