scheduling_lms_discrete.py 27 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import math
15
import warnings
16
from dataclasses import dataclass
17
from typing import List, Literal, Optional, Tuple, Union
18
19

import numpy as np
20
import scipy.stats
21
22
23
24
import torch
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
Kashif Rasul's avatar
Kashif Rasul committed
25
26
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
27
28
29


@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Result container produced by [`LMSDiscreteScheduler.step`] when `return_dict=True`.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` computed for the previous timestep. Feed it back to the model as the input of the
            next iteration of the denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The estimate of the fully denoised sample `(x_{0})` derived from the model output at the current
            timestep. Handy for previewing progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
46
47


48
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Discretize a continuous `alpha_bar(t)` curve into a sequence of per-step betas.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` over diffusion time `t` in `[0, 1]`. For step `i` the
    beta is `1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)` with `t_i = i / num_diffusion_timesteps`, and every beta is
    clipped from above at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            How many betas to generate.
        max_beta (`float`, defaults to `0.999`):
            Upper bound applied to each beta; keep it below 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            Which `alpha_bar` curve to discretize: the shifted squared-cosine curve or an exponential decay.

    Returns:
        `torch.Tensor`:
            A 1-D `float32` tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """

    def _cosine_curve(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_curve(t):
        return math.exp(t * -12.0)

    # Validate before doing any work, so an unknown type fails even when no betas are requested.
    if alpha_transform_type == "cosine":
        alpha_bar = _cosine_curve
    elif alpha_transform_type != "exp":
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    else:
        alpha_bar = _exp_curve

    steps = num_diffusion_timesteps
    clipped_betas = [
        min(1 - alpha_bar((idx + 1) / steps) / alpha_bar(idx / steps), max_beta) for idx in range(steps)
    ]
    return torch.tensor(clipped_betas, dtype=torch.float32)


94
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    A linear multistep scheduler for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of the schedule.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value of the schedule.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.

    At most one of `use_karras_sigmas`, `use_exponential_sigmas` and `use_beta_sigmas` may be enabled.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # `self.config` is read here; it is expected to have been populated by `register_to_config`.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), ordered from noisiest to cleanest, with a trailing 0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        self.use_karras_sigmas = use_karras_sigmas
        self.set_timesteps(num_train_timesteps, None)
        self.derivatives = []
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute the linear multistep coefficient.

        The coefficient is the integral, over `[sigmas[t], sigmas[t + 1]]`, of the Lagrange basis polynomial that is
        1 at `sigmas[t - current_order]` and 0 at every other node `sigmas[t - k]` for `k` in `range(order)`.

        Args:
            order (`int`):
                The number of interpolation nodes (i.e. past derivatives) used by the multistep method.
            t (`int`):
                The current step index into `self.sigmas`.
            current_order (`int`):
                Which node the coefficient is for: 0 is the current step, 1 the previous step, and so on.

        Returns:
            `float`:
                The coefficient that multiplies the derivative recorded at step `t - current_order`.
        """

        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        This also resets the step index, the begin index and the stored derivative history, so a new sampling run
        starts from a clean state.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

        self.derivatives = []

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(
        self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index of a given timestep in the timestep schedule.

        Args:
            timestep (`float` or `torch.Tensor`):
                The timestep value to find in the schedule.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule. For the very first step, returns the second index if
                multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
        """
        Initialize the step index for the scheduler based on the given timestep.

        Args:
            timestep (`float` or `torch.Tensor`):
                The current timestep to initialize the step index from.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Convert sigma values to corresponding timestep values through interpolation.

        Args:
            sigma (`np.ndarray`):
                The sigma value(s) to convert to timestep(s).
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s).
        """
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following an exponential schedule.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`, defaults to 4):
                The maximum order of the linear multistep method. The order actually used is capped by the number of
                derivatives accumulated since the last `set_timesteps` call, so the first step taken is an Euler step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`LMSDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`LMSDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`LMSDiscreteSchedulerOutput`] is returned, otherwise a tuple
                `(prev_sample, pred_original_sample)` is returned.

        Raises:
            ValueError: If `config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        # Keep at most `order` derivatives; a loop (not a single pop) also handles callers lowering `order` mid-run.
        while len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        # The usable order is the number of derivatives actually stored. Deriving it from `step_index + 1` instead
        # over-estimates it when sampling starts mid-schedule (e.g. image-to-image with a non-zero begin index),
        # which would pair high-order coefficients with a single derivative. When sampling starts at index 0 both
        # quantities are equal.
        order = len(self.derivatives)
        lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        # coefficient `curr_order` pairs with the derivative recorded `curr_order` steps ago (newest first).
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise tensor to add to the original samples.
            timesteps (`torch.Tensor`):
                The timesteps at which to add noise, determining the noise level from the schedule.

        Returns:
            `torch.Tensor`:
                The noisy samples with added noise scaled according to the timestep schedule.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps