# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import math
15
import warnings
16
from dataclasses import dataclass
17
from typing import List, Literal, Optional, Tuple, Union
18
19

import numpy as np
20
import scipy.stats
21
22
23
24
import torch
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
Kashif Rasul's avatar
Kashif Rasul committed
25
26
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
27
28
29


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample for the previous timestep; fed back into the model on the next denoising step.
    prev_sample: torch.Tensor
    # Model's estimate of the fully denoised sample x_0; may be None when not computed.
    pred_original_sample: Optional[torch.Tensor] = None
46
47


48
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
YiYi Xu's avatar
YiYi Xu committed
49
def betas_for_alpha_bar(
50
51
52
53
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
54
55
56
57
58
59
60
61
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
62
63
64
65
66
67
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
68
69

    Returns:
70
71
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
72
    """
YiYi Xu's avatar
YiYi Xu committed
73
    if alpha_transform_type == "cosine":
74

YiYi Xu's avatar
YiYi Xu committed
75
76
77
78
79
80
81
82
83
        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
M. Tolga Cangöz's avatar
M. Tolga Cangöz committed
84
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
85
86
87
88
89

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
YiYi Xu's avatar
YiYi Xu committed
90
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
91
92
93
    return torch.tensor(betas, dtype=torch.float32)


94
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    A linear multistep scheduler for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to `1000`):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to `0.0001`):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to `0.02`):
            The final `beta` value.
        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://huggingface.co/papers/2210.02303) paper).
        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to `0`):
            An offset added to the inference steps, as required by some model families.
    """

    # Names of schedulers this one can be swapped with.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # Solver order exposed to pipelines (the LMS order itself is chosen per `step` call).
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
        steps_offset: int = 0,
    ):
        # The three special sigma schedules are mutually exclusive.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t); reversed so sigmas decrease, with a final 0.0 appended.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        self.use_karras_sigmas = use_karras_sigmas
        # Initialize timesteps/sigmas with the full training schedule until `set_timesteps` is called.
        self.set_timesteps(num_train_timesteps, None)
        self.derivatives = []
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
YiYi Xu's avatar
YiYi Xu committed
183

184
    @property
    def init_noise_sigma(self) -> Union[float, torch.Tensor]:
        """
        The standard deviation of the initial noise distribution.

        Returns:
            `float` or `torch.Tensor`:
                The initial noise scale, derived from the largest sigma in the schedule and the configured timestep
                spacing.
        """
        max_sigma = self.sigmas.max()
        if self.config.timestep_spacing not in ["linspace", "trailing"]:
            # remaining spacing ("leading") uses sqrt(sigma_max^2 + 1)
            return (max_sigma**2 + 1) ** 0.5
        return max_sigma

YiYi Xu's avatar
YiYi Xu committed
200
    @property
    def step_index(self) -> Optional[int]:
        """
        Index counter for the current timestep; advances by one on every scheduler step.

        Returns:
            `int` or `None`:
                The current step index, or `None` before initialization.
        """
        return self._step_index

211
    @property
    def begin_index(self) -> Optional[int]:
        """
        Index of the first timestep, set from the pipeline via `set_begin_index`.

        Returns:
            `int` or `None`:
                The begin index for the scheduler, or `None` if it was never set.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0) -> None:
        """
        Set the scheduler's begin index. Pipelines should call this before running inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

233
    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # Divide by sqrt(sigma^2 + 1) so the model sees unit-variance input.
        current_sigma = self.sigmas[self.step_index]
        scale = (current_sigma**2 + 1) ** 0.5
        self.is_scale_input_called = True
        return sample / scale
256

257
    def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
        """
        Compute the linear multistep coefficient.

        Args:
            order (`int`):
                The order of the linear multistep method.
            t (`int`):
                The current timestep index.
            current_order (`int`):
                The current order for which to compute the coefficient.

        Returns:
            `float`:
                The computed linear multistep coefficient.
        """

        def lms_derivative(tau):
            # Lagrange basis polynomial for index `current_order` over the trailing `order` sigmas.
            prod = 1.0
            for k in range(order):
                if k == current_order:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        # Integrate the basis polynomial between the current and next sigma.
        lower, upper = self.sigmas[t], self.sigmas[t + 1]
        return integrate.quad(lms_derivative, lower, upper, epsrel=1e-4)[0]

286
    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            # Evenly spaced (float) timesteps over [0, num_train_timesteps - 1], descending.
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Recompute the full training-sigma curve, then interpolate it at the selected timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optional alternative sigma spacings; each re-derives the timesteps from the new sigmas.
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        # Append a final 0.0 sigma so the last step lands on the fully denoised sample.
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
        # Reset per-run state so the next denoising loop starts fresh.
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

        self.derivatives = []

344
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(
        self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index of a given timestep in the timestep schedule.

        Args:
            timestep (`float` or `torch.Tensor`):
                The timestep value to find in the schedule.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule. For the very first step, returns the second index if
                multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        matches = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        chosen = 0 if len(matches) <= 1 else 1
        return matches[chosen].item()
YiYi Xu's avatar
YiYi Xu committed
374

375
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
        """
        Initialize the step index for the scheduler based on the given timestep.

        Args:
            timestep (`float` or `torch.Tensor`):
                The current timestep to initialize the step index from.
        """
        if self.begin_index is not None:
            # A pipeline-provided begin index takes precedence over timestep lookup.
            self._step_index = self._begin_index
        else:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
YiYi Xu's avatar
YiYi Xu committed
390

391
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
        """
        Convert sigma values to corresponding timestep values through interpolation.

        Args:
            sigma (`np.ndarray`):
                The sigma value(s) to convert to timestep(s).
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s).
        """
        # get log sigma (floor at 1e-10 to avoid log(0))
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range: bracket each query between two consecutive schedule entries,
        # clipping so high_idx never runs past the end of the table
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas (weight clipped to [0, 1] to stay inside the bracket)
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

428
    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
        """
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.

        Returns:
            `torch.Tensor`:
                The converted sigma values following the Karras noise schedule.
        """
        rho = 7.0  # 7.0 is the value used in the paper
        sigma_min = in_sigmas[-1].item()
        sigma_max = in_sigmas[0].item()

        # Interpolate linearly in sigma^(1/rho) space, then raise back to the rho power.
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        ramp = np.linspace(0, 1, self.num_inference_steps)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

452
453
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following an exponential schedule.
        """
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)

        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        # Evenly spaced in log-sigma, descending from sigma_max to sigma_min.
        return np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))

486
487
488
489
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        """
        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        sigma_min = getattr(self.config, "sigma_min", None)
        sigma_max = getattr(self.config, "sigma_max", None)

        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()

        # Map descending percentiles through the beta-distribution quantile function,
        # then rescale each quantile into [sigma_min, sigma_max].
        sigmas = []
        for timestep in 1 - np.linspace(0, 1, num_inference_steps):
            ppf = scipy.stats.beta.ppf(timestep, alpha, beta)
            sigmas.append(sigma_min + (ppf * (sigma_max - sigma_min)))
        return np.array(sigmas)

535
536
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`, defaults to 4):
                The order of the linear multistep method.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `self.config.prediction_type` is not one of `epsilon`, `sample`, or `v_prediction`.
        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            # Fix: the error message previously omitted `sample`, which is handled above.
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 2. Convert to an ODE derivative; keep at most `order` past derivatives for the multistep update.
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients (order is capped early in the schedule
        # when fewer derivatives are available)
        order = min(self.step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path (newest derivative pairs
        # with the first coefficient, hence `reversed`)
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
614

615
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise tensor to add to the original samples.
            timesteps (`torch.Tensor`):
                The timesteps at which to add noise, determining the noise level from the schedule.

        Returns:
            `torch.Tensor`:
                The noisy samples with added noise scaled according to the timestep schedule.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast sigma to the sample's rank by appending singleton dimensions.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

664
    def __len__(self) -> int:
        """Return the number of training timesteps configured for this scheduler."""
        return self.config.num_train_timesteps