"CODE_OF_CONDUCT.md" did not exist on "ad08b8ce131bacd6f61dfcd49e5f1af3cac76ca7"
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


@dataclass
class DDIMSchedulerOutput(BaseOutput):
    """
    Container for what the DDIM scheduler's `step` function hands back.

    Args:
        prev_sample (`torch.FloatTensor`, e.g. `(batch_size, num_channels, height, width)` for images):
            The sample x_{t-1} computed for the previous timestep. Feed it back to the model as the next input of the
            denoising loop.
        pred_original_sample (`torch.FloatTensor`, e.g. `(batch_size, num_channels, height, width)` for images):
            The estimate of the fully denoised sample x_{0}, derived from the model output at the current timestep.
            Handy for previewing progress or for guidance.
    """

    # x_{t-1}: the next input of the denoising loop
    prev_sample: torch.FloatTensor
    # estimate of x_0; left as None when the caller does not provide it
    pred_original_sample: Union[torch.FloatTensor, None] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    The inner function `alpha_bar` maps a time t in [0, 1] to the cumulative product of (1-beta) up to that point of
    the diffusion process. Here it is the squared-cosine curve `cos((t + 0.008) / 1.008 * pi / 2) ** 2`. Each beta is
    `1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)` on a uniform grid of `num_diffusion_timesteps` intervals, capped at
    `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.

    Returns:
        betas (`torch.Tensor`): 1-D float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    n = num_diffusion_timesteps
    # The cap matters near t = 1, where alpha_bar -> 0 and the ratio would otherwise push beta to 1.
    betas = [min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta) for i in range(n)]
    # float32 keeps this schedule in the same dtype as the linear schedules (torch would otherwise infer float64).
    return torch.tensor(betas, dtype=torch.float32)


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the first `beta` value of the training noise schedule.
        beta_end (`float`): the final `beta` value of the training noise schedule.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            Converted to a float32 tensor.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        **kwargs,
    ):
        deprecate(
            "tensor_format",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

        if trained_betas is not None:
            # `as_tensor` accepts ndarrays, lists and tensors. Forcing float32 keeps every schedule in one dtype, so
            # float64 betas do not later promote `add_noise` results to float64.
            self.betas = torch.as_tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule; cast so the dtype matches the other schedules
            self.betas = betas_for_alpha_bar(num_train_timesteps).to(torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod.
        # For the final step, there is no previous alphas_cumprod because we are already at 0.
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. DDIM needs no scaling, so the sample is returned unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """
        Return `(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)`, where `a` is `alphas_cumprod`. This is the variance
        term whose square root, multiplied by eta, gives sigma_t in formula (16) of the DDIM paper.

        A negative `prev_timestep` means "before the first training step" and uses `final_alpha_cumprod`.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. Must be between
                1 and `num_train_timesteps`.
            device (`str` or `torch.device`, optional):
                the device to move the timesteps to. If `None`, the timesteps stay on the CPU.

        Raises:
            ValueError: if `num_inference_steps` is not in `[1, num_train_timesteps]`.
        """
        deprecated_offset = deprecate(
            "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
        )
        offset = deprecated_offset or self.config.steps_offset

        # Beyond num_train_timesteps the integer step ratio below becomes 0 (every timestep would be identical and
        # `step` would not move), and 0 would divide by zero.
        if not 0 < num_inference_steps <= self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} must be between 1 and"
                f" `num_train_timesteps`: {self.config.num_train_timesteps}."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Creates integer timesteps by multiplying by the ratio. The explicit int64 cast pins the dtype used to
        # index `alphas_cumprod` instead of relying on numpy's platform-dependent default integer.
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += offset

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model (treated as predicted noise).
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`):
                if `True`, re-derive the predicted noise from the (possibly clipped) predicted original sample before
                computing the "direction pointing to x_t". When `clip_sample` is `False` this only differs from
                `model_output` by floating-point rounding.
            generator: random number generator used to draw the noise when `eta > 0`.
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

        Returns:
            [`DDIMSchedulerOutput`] or `tuple`:
            [`DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first
            element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
            # σ_t * z: std_dev_t already equals eta * sqrt(variance), so there is no need to recompute it
            prev_sample = prev_sample + std_dev_t * noise

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Diffuse `original_samples` forward to `timesteps`:
        `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples; the first dimension is matched to `timesteps`.
            noise (`torch.FloatTensor`): noise broadcastable to `original_samples`.
            timesteps (`torch.IntTensor`): one timestep per sample (or a single timestep).

        Returns:
            `torch.FloatTensor`: the noisy samples, in the dtype of `original_samples`.
        """
        # Keep the cached schedule on the samples' device (as before). Cast locally to the samples' dtype so that,
        # e.g., fp16 inputs are not silently promoted to float32 by the float32 schedule.
        if self.alphas_cumprod.device != original_samples.device:
            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)

        if timesteps.device != original_samples.device:
            timesteps = timesteps.to(original_samples.device)

        # Per-timestep coefficients are flattened, then right-padded with singleton dims to broadcast over the
        # trailing (non-batch) dimensions of the samples.
        sqrt_alpha_prod = (alphas_cumprod[timesteps] ** 0.5).flatten()
        while sqrt_alpha_prod.ndim < original_samples.ndim:
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = ((1 - alphas_cumprod[timesteps]) ** 0.5).flatten()
        while sqrt_one_minus_alpha_prod.ndim < original_samples.ndim:
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Number of training timesteps this scheduler was configured with."""
        return self.config.num_train_timesteps