scheduling_euler_discrete.py 21.2 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
hlky's avatar
hlky committed
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import math
hlky's avatar
hlky committed
16
from dataclasses import dataclass
17
from typing import List, Optional, Tuple, Union
hlky's avatar
hlky committed
18
19
20
21
22

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
Dhruv Nair's avatar
Dhruv Nair committed
23
24
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
Kashif Rasul's avatar
Kashif Rasul committed
25
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
hlky's avatar
hlky committed
26
27
28
29
30
31
32
33
34


# Module-level logger; used by `EulerDiscreteScheduler.step` to warn when `scale_model_input` was not called first.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EulerDiscreteSchedulerOutput(BaseOutput):
    """
    Result container produced by the scheduler's `step` function when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample for the previous timestep, `(x_{t-1})`. Feed it back into the model as the input of the next
            iteration of the denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The estimate of the fully denoised sample `(x_{0})` derived from the model output at the current
            timestep. Handy for previewing progress or for guidance; defaults to `None`.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


50
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, *optional*, defaults to 0.999):
            The maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `"cosine"`):
            The type of noise schedule for alpha_bar. Choose from `"cosine"` or `"exp"`.

    Returns:
        betas (`torch.Tensor` of shape `(num_diffusion_timesteps,)` and dtype `torch.float32`):
            The betas used by the scheduler to step the model outputs.

    Raises:
        ValueError: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            # Squared-cosine alpha_bar with a small 0.008 offset applied to t.
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1); clipped to `max_beta` so the last betas never reach 1.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


hlky's avatar
hlky committed
95
96
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear` or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        interpolation_type (`str`, defaults to `"linear"`, *optional*):
            The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            `"linear"` or `"log_linear"`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        sigma_min (`float`, *optional*):
            Lower end of the Karras sigma range. When `None`, the smallest interpolated sigma is used.
        sigma_max (`float`, *optional*):
            Upper end of the Karras sigma range. When `None`, the largest interpolated sigma is used.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled; one of `"linspace"`, `"leading"` or `"trailing"`. Refer to Table 2
            of the [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891) for more information.
        timestep_type (`str`, defaults to `"discrete"`):
            `"discrete"` or `"continuous"`. With `"continuous"` and `prediction_type="v_prediction"` the timesteps
            are `0.25 * log(sigma)` instead of training-step indices.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference timesteps when `timestep_spacing="leading"`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        timestep_spacing: str = "linspace",
        timestep_type: str = "discrete",  # can be "discrete" or "continuous"
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), indexed by training timestep (ascending).
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()

        # Flip sigmas so both sigmas and timesteps run from the noisiest step to the cleanest one.
        sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # setable values
        self.num_inference_steps = None

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if timestep_type == "continuous" and prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
        else:
            self.timesteps = timesteps

        # A trailing zero sigma lets the final `step` read `sigmas[step_index + 1]`.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

        self._step_index = None

    @property
    def init_noise_sigma(self):
        """
        Standard deviation of the initial noise distribution.

        For `"linspace"`/`"trailing"` spacing this is the largest sigma; otherwise it is `sqrt(sigma_max**2 + 1)`.
        """
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Only used to initialize the step index on the first
                call; it must be one of `self.timesteps`.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)

        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: If `timestep_spacing` or `interpolation_type` in the config is not supported.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        elif self.config.interpolation_type == "log_linear":
            # Convert back to NumPy: everything below (Karras conversion, `torch.from_numpy`) expects an ndarray,
            # and passing a torch tensor to `torch.from_numpy` raises a TypeError.
            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
        else:
            raise ValueError(
                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
                " 'linear' or 'log_linear'"
            )

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            # Re-derive the (fractional) training timesteps that correspond to the Karras sigmas.
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
        else:
            self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self._step_index = None

    def _sigma_to_t(self, sigma, log_sigmas):
        """
        Map `sigma` to a fractional training timestep by linear interpolation in log-sigma space.

        Args:
            sigma: The sigma value(s) to convert (NumPy scalar or array).
            log_sigmas (`np.ndarray`): `log` of the training sigmas, indexed by training timestep.

        Returns:
            `np.ndarray` with the same shape as `sigma`, holding the interpolated timesteps.
        """
        # get log sigma (clamped to avoid log(0))
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022).

        `in_sigmas` is expected in descending order: its first entry is used as the default `sigma_max` and its last
        entry as the default `sigma_min`, unless the config overrides them. Returns a NumPy array of
        `num_inference_steps` sigmas interpolated in `sigma ** (1 / rho)` space.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _init_step_index(self, timestep):
        """
        Initialize `self._step_index` from the first timestep seen after `set_timesteps`.

        Raises:
            IndexError: If `timestep` does not occur in `self.timesteps`.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            raise IndexError(
                f"Timestep {timestep} is not one of `scheduler.timesteps`. Make sure to pass a value taken from"
                " `scheduler.timesteps`."
            )

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]

        self._step_index = step_index.item()

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain; must be one of `scheduler.timesteps`.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Amount of stochasticity. The churn factor `gamma` is `min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)`;
                with the default of 0.0 no extra noise is injected.
            s_tmin (`float`, defaults to 0.0):
                Churn is only applied when `s_tmin <= sigma <= s_tmax`.
            s_tmax (`float`, defaults to `inf`):
                Churn is only applied when `s_tmin <= sigma <= s_tmax`.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer index, or if `prediction_type` is not supported.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Noise is drawn unconditionally so the generator state advances the same way whether or not churn applies.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            # Raise the noise level from sigma to sigma_hat before taking the Euler step.
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Noise `original_samples` to the levels given by `timesteps`: `original_samples + noise * sigma(t)`.

        Each entry of `timesteps` must occur exactly once in `self.timesteps`; otherwise the index lookup raises.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Broadcast the per-sample sigma over the remaining (non-batch) dimensions.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps