# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


# Module-level logger from the package's own `logging` utility (imported from `..utils`); used by
# `EulerDiscreteScheduler.step` to warn when `scale_model_input` was not called first.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): 1-D `torch.float32` tensor of length `num_diffusion_timesteps`, where
        `betas[i] = min(1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), max_beta)`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            # Offset 0.008 keeps the first betas from being vanishingly small near t = 0.
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Clip to max_beta: near t = 1 the cosine alpha_bar approaches 0 and beta would approach 1.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        interpolation_type(`str`, defaults to `"linear"`, *optional*):
            The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            `"linear"` or `"log_linear"`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled: `"linspace"`, `"leading"` or `"trailing"`. Refer to Table 2 of the
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)
            for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference timesteps. It is only applied when `timestep_spacing` is `"leading"`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every training timestep, reversed so the schedule runs
        # from the last training timestep to the first, with a terminal 0.0 appended.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

        # Position in `self.sigmas`; lazily initialised from the first timestep seen (see `_init_step_index`).
        self._step_index = None

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase by 1 after each scheduler step.
        """
        return self._step_index

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index on the first call.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)

        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: If the configured `timestep_spacing` or `interpolation_type` is not supported.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Training sigmas indexed by training timestep; `log_sigmas` is kept to map Karras sigmas back to timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        elif self.config.interpolation_type == "log_linear":
            # Exactly one sigma per inference step, geometrically spaced from the largest training sigma down to the
            # smallest (like k-diffusion's `get_sigmas_exponential`). The terminal 0.0 is appended below, so the
            # schedule has the same length as in the "linear" and Karras paths and stays aligned with `timesteps`.
            sigmas = np.exp(np.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps))
        else:
            raise ValueError(
                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
                " 'linear' or 'log_linear'"
            )

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        # A new schedule invalidates any in-progress stepping position.
        self._step_index = None

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: Union[np.ndarray, torch.FloatTensor], num_inference_steps) -> np.ndarray:
        """Constructs the noise schedule of Karras et al. (2022).

        Takes sigma_max from `in_sigmas[0]` and sigma_min from `in_sigmas[-1]` and returns a numpy array of
        `num_inference_steps` sigmas interpolated in `sigma ** (1 / rho)` space.
        """

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _init_step_index(self, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]

        self._step_index = step_index.item()

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Amount of stochasticity. The sigma is inflated by `gamma = min(s_churn / (len(self.sigmas) - 1),
                sqrt(2) - 1)` before the Euler step; `0.0` gives a deterministic step.
            s_tmin (`float`, defaults to 0.0):
                Churn is only applied when `s_tmin <= sigma`.
            s_tmax (`float`, defaults to `inf`):
                Churn is only applied when `sigma <= s_tmax`.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer index, or `prediction_type` is not supported.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Noise is drawn unconditionally so the generator state advances identically whether or not churn is used.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Noise `original_samples` to the sigma of each given timestep: `original_samples + noise * sigma(t)`.

        Args:
            original_samples (`torch.FloatTensor`):
                The clean samples.
            noise (`torch.FloatTensor`):
                The noise to add; presumably the same shape as `original_samples` (it must broadcast against it).
            timesteps (`torch.FloatTensor`):
                One timestep per sample. Each value must occur exactly once in `self.timesteps`, otherwise the
                index lookup (`.nonzero().item()`) raises.

        Returns:
            `torch.FloatTensor`: The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append singleton dims so the per-sample sigma broadcasts over all non-batch dimensions.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps