# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import scipy.stats
import torch
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    A linear multistep scheduler for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
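
        # sigma_t = ((1 - alphas_cumprod_t) / alphas_cumprod_t) ** 0.5, reversed to run from the largest noise
        # level down to the smallest, with a trailing 0.0 appended for the final denoising step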
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        self.use_karras_sigmas = use_karras_sigmas
        self.set_timesteps(num_train_timesteps, None)
        self.derivatives = []
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It will increase by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
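        # scale the input by 1 / sqrt(sigma**2 + 1) so the model input stays at roughly unit variance
        # across noise levels (assuming roughly unit-variance data)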
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute the linear multistep coefficient.

        Args:
            order (`int`):
                The order of the linear multistep method.
            t (`int`):
                The index of the current timestep in the sigma schedule.
            current_order (`int`):
                Which previous derivative the coefficient applies to, counting back from the current step
                (0 is the most recent).
        """

        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

        self.derivatives = []

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
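        # Map a continuous sigma value back to a fractional timestep index by interpolating linearly
        # between the two nearest entries of `log_sigmas` in log-space.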
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
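        # sigmas are spaced uniformly in sigma**(1 / rho) between sigma_max and sigma_min, the schedule
        # proposed by Karras et al. (2022), which concentrates steps toward the low-noise end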
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`, defaults to 4):
                The order of the linear multistep method.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
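        # for the probability flow ODE in this sigma parameterization, dx/dsigma = (x - x_0_pred) / sigma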
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        order = min(self.step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add_noise is called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
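

# Illustrative usage sketch (kept as a comment): a minimal denoising loop with this scheduler. The shapes and
# the `model(...)` call are hypothetical stand-ins for an arbitrary noise-prediction network, not part of this module.
#
#     scheduler = LMSDiscreteScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear")
#     scheduler.set_timesteps(num_inference_steps=50)
#     sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
#     for t in scheduler.timesteps:
#         model_input = scheduler.scale_model_input(sample, t)
#         noise_pred = model(model_input, t)  # hypothetical epsilon-prediction model
#         sample = scheduler.step(noise_pred, t, sample, order=4).prev_sample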