# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
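
    # Discretize beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at `max_beta`, so that the cumulative
    # product of (1 - beta) tracks the chosen alpha_bar curve.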
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    A linear multistep scheduler for discrete beta schedules.

    This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1
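
    # A minimal denoising-loop sketch (illustrative only; `unet` and its `.sample` output attribute are assumptions
    # about the caller's model, not part of this scheduler):
    #
    #     scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    #     scheduler.set_timesteps(num_inference_steps=50)
    #     sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
    #     for t in scheduler.timesteps:
    #         model_input = scheduler.scale_model_input(sample, t)
    #         noise_pred = unet(model_input, t).sample
    #         sample = scheduler.step(noise_pred, t, sample).prev_sample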

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
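
        # Convert the variance-preserving schedule to the k-diffusion (variance-exploding) parameterization:
        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t); sigmas are stored in decreasing order with a final 0.0.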
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        self.use_karras_sigmas = use_karras_sigmas
        self.set_timesteps(num_train_timesteps, None)
        self.derivatives = []
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It will increase by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index of the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
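        # c_in preconditioning from k-diffusion: scale the model input by 1 / sqrt(sigma^2 + 1).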
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute the linear multistep coefficient.

        Args:
            order (`int`):
                The order of the linear multistep method.
            t (`int`):
                The index of the current sigma in the sigma schedule.
            current_order (`int`):
                The index of the previous derivative whose coefficient is computed.
        """

        def lms_derivative(tau):
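            # Lagrange basis polynomial for the `current_order`-th previous sigma, evaluated at tau; integrating it
            # over [sigma_t, sigma_{t+1}] yields the Adams-Bashforth-style multistep weight.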
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", and "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )
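
        # Map the (possibly fractional) timesteps to sigmas by linearly interpolating the full training-time
        # sigma table derived from `alphas_cumprod`.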
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

        self.derivatives = []

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`, defaults to 4):
                The order of the linear multistep method.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
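        # (the effective order ramps up over the first steps, until `order` past derivatives have been collected)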
        order = min(self.step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when the scheduler is used for training, or when the pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
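        # Broadcast sigma over the sample dimensions; the forward process in this parameterization is
        # x_t = x_0 + sigma_t * noise.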
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps