# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
19
from dataclasses import dataclass
20
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
21

Patrick von Platen's avatar
Patrick von Platen committed
22
import numpy as np
23
import torch
Patrick von Platen's avatar
Patrick von Platen committed
24

25
from ..configuration_utils import ConfigMixin, register_to_config
26
from ..utils import BaseOutput
27
28
29
30
from .scheduling_utils import SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; every beta is clipped to this value. Use values lower than 1 to
            prevent singularities.

    Returns:
        betas (`torch.Tensor` of dtype `torch.float32` and shape `(num_diffusion_timesteps,)`): the betas used by the
        scheduler to step the model outputs.
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar curve with a small (0.008) offset on t.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta (alpha_bar(1) is ~0).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    # Explicit float32 so the result does not depend on torch's global default dtype and matches the
    # float32 `linear` / `scaled_linear` schedules.
    return torch.tensor(betas, dtype=torch.float32)


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of the schedule.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `list`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            The values are converted to a `torch.float32` tensor.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
    """

    _compatible_classes = [
        "PNDMScheduler",
        "DDPMScheduler",
        "LMSDiscreteScheduler",
        "EulerDiscreteScheduler",
        "EulerAncestralDiscreteScheduler",
    ]

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, list]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            # `torch.tensor` (unlike `torch.from_numpy`) also accepts plain Python lists, and the explicit dtype keeps
            # user-supplied betas consistent with the float32 schedules built below.
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """
        Compute the DDIM variance term of formula (16), without the `eta` factor:
        `(1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)`.

        A negative `prev_timestep` (past the start of the chain) uses `self.final_alpha_cumprod` as the previous
        alpha product.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. Must be between
                1 and `num_train_timesteps`.
            device (`str` or `torch.device`, optional):
                the device the timesteps tensor is moved to. If `None`, the timesteps stay on the CPU.

        Raises:
            ValueError: if `num_inference_steps` is not between 1 and `num_train_timesteps`.
        """
        num_train_timesteps = self.config.num_train_timesteps
        # Validate before mutating state: too many steps gives a step ratio of 0 (every timestep would be 0),
        # and 0 steps would divide by zero.
        if not 1 <= num_inference_steps <= num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} must be between 1 and"
                f" `self.config.num_train_timesteps`: {num_train_timesteps}, the number of timesteps the model was"
                " trained with."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.config.steps_offset

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called, or if both `generator` and `variance_noise` are
                given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1); the stride matches the one used in `set_timesteps`
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
                device = model_output.device if torch.is_tensor(model_output) else "cpu"
                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(
                    device
                )

            # sigma_t * z: reuse std_dev_t from step 5 instead of recomputing the variance
            prev_sample = prev_sample + std_dev_t * variance_noise

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Noise `original_samples` to the given `timesteps`:
        `sqrt(alpha_prod_t) * original_samples + sqrt(1 - alpha_prod_t) * noise`.

        Args:
            original_samples (`torch.FloatTensor`): the clean samples.
            noise (`torch.FloatTensor`): the noise to mix in; expected to broadcast with `original_samples`.
            timesteps (`torch.IntTensor`): integer timesteps. They are flattened and broadcast along the leading
                dimension of `original_samples` (presumably one per sample, or a single value).

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        # Move alphas_cumprod and timesteps to the samples' device/dtype on a local copy, so repeated calls do not
        # permanently downcast (e.g. to fp16) the scheduler's own alphas_cumprod used later by `step`.
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        alpha_prod = alphas_cumprod[timesteps].flatten()
        sqrt_alpha_prod = alpha_prod**0.5
        sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5
        # Append trailing singleton dims so the per-timestep coefficients broadcast over the sample dims.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps (`config.num_train_timesteps`)."""
        return self.config.num_train_timesteps