# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): float32 tensor of shape `(num_diffusion_timesteps,)` holding the betas used by the
        scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Squared-cosine alpha_bar(t); this is the curve behind the `squaredcos_cap_v2` beta schedule.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        # Each beta is derived from the ratio of alpha_bar at the end and start of its interval.
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # alpha_bar(1) == cos(pi / 2) ** 2 ~= 0, so the last ratio is ~0 and its beta would be ~1 without the cap.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of the schedule.
        beta_end (`float`): the final `beta` value of the schedule.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the denoised sample x_0) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. DDIM needs no scaling, so the sample is returned unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep (unused)

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """
        Return the DDIM variance term of formula (16) (without the `eta` factor):
        `(1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)`.

        A negative `prev_timestep` means the chain has run past its first training step, in which case
        `self.final_alpha_cumprod` stands in for `alpha_bar_prev`.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
        dynamic_max_val = (
            sample.flatten(1)
            .abs()
            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
            .clamp_min(self.config.sample_max_value)
            .view(-1, *([1] * (sample.ndim - 1)))
        )
        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device the timesteps are moved to. If `None`, they stay on the CPU.

        Raises:
            ValueError: if `num_inference_steps` exceeds `self.config.num_train_timesteps`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                " with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.config.steps_offset

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to
                `[-clip_sample_range, clip_sample_range]` when `self.config.clip_sample` is `True` (and rescaled
                when `self.config.thresholding` is `True`). If no clipping has happened, "corrected" `model_output`
                would coincide with the one provided as input and `use_clipped_model_output` will have no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called, if `prediction_type` is not supported, or if both
                `generator` and `variance_noise` are given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # With eta == 0 the update is deterministic; only eta > 0 injects sigma_t-scaled noise.
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def _get_sqrt_alpha_prods(
        self, timesteps: torch.IntTensor, reference: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Look up `sqrt(alpha_bar_t)` and `sqrt(1 - alpha_bar_t)` for `timesteps`, shaped to broadcast against
        `reference`: one value per entry of `timesteps` along the leading dim, singleton trailing dims.

        A local copy of `alphas_cumprod` is cast to `reference`'s device and dtype rather than overwriting
        `self.alphas_cumprod`. Otherwise calling `add_noise`/`get_velocity` with e.g. fp16 inputs would permanently
        reduce the precision used by later `step` calls.
        """
        alphas_cumprod = self.alphas_cumprod.to(device=reference.device, dtype=reference.dtype)
        alpha_prod = alphas_cumprod[timesteps.to(reference.device)].flatten()

        sqrt_alpha_prod = alpha_prod**0.5
        sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5

        # Append singleton dims until the ranks match so each timestep value lines up with the leading dim.
        while len(sqrt_alpha_prod.shape) < len(reference.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        return sqrt_alpha_prod, sqrt_one_minus_alpha_prod

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Diffuse `original_samples` forward to `timesteps`:
        `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples x_0.
            noise (`torch.FloatTensor`): noise to mix in; must broadcast against `original_samples`.
            timesteps (`torch.IntTensor`): training timesteps, one per leading entry of `original_samples`.

        Returns:
            `torch.FloatTensor`: the noisy samples, on the device and dtype of `original_samples`.
        """
        sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._get_sqrt_alpha_prods(timesteps, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Compute the v-prediction target `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample`.

        Args:
            sample (`torch.FloatTensor`): clean samples x_0.
            noise (`torch.FloatTensor`): noise used to diffuse `sample`; must broadcast against it.
            timesteps (`torch.IntTensor`): training timesteps, one per leading entry of `sample`.

        Returns:
            `torch.FloatTensor`: the velocity, on the device and dtype of `sample`.
        """
        sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._get_sqrt_alpha_prods(timesteps, sample)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps