scheduling_ddpm.py 16.8 KB
Newer Older
Ryan Russell's avatar
Ryan Russell committed
1
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
Patrick von Platen's avatar
improve  
Patrick von Platen committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
16

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

anton-l's avatar
anton-l committed
17
import math
18
from dataclasses import dataclass
19
from typing import List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
20

Patrick von Platen's avatar
Patrick von Platen committed
21
import numpy as np
22
import torch
Patrick von Platen's avatar
improve  
Patrick von Platen committed
23

24
25
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
Kashif Rasul's avatar
Kashif Rasul committed
26
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44


@dataclass
class DDPMSchedulerOutput(BaseOutput):
    """
    Container returned by [`DDPMScheduler.step`] when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample x_{t-1} computed for the previous timestep. Feed it back to the model as the input of the next
            iteration of the denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The model's estimate of the fully denoised sample x_0, derived from the output at the current timestep.
            Handy for previewing progress or for guidance. Defaults to `None`.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
45
46
47
48


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process. Here alpha_bar is the squared-cosine ("Glide") schedule
    `cos((t + 0.008) / 1.008 * pi / 2) ** 2`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.

    Returns:
        betas (`torch.Tensor`): 1-D `torch.float32` tensor of length `num_diffusion_timesteps` holding the betas used
        by the scheduler to step the model outputs.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i) on the uniform grid t_i = i / N. The cap is needed because
    # alpha_bar(1) is ~0, which would otherwise push the final beta to 1.
    betas = [
        min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
Patrick von Platen's avatar
improve  
Patrick von Platen committed
74
75


Patrick von Platen's avatar
Patrick von Platen committed
76
class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
    Langevin dynamics sampling.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of the schedule.
        beta_end (`float`): the final `beta` value of the schedule.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip the predicted original sample to `[-clip_sample_range, clip_sample_range]` for numerical
            stability.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the denoised sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude used when `clip_sample=True`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        clip_sample_range: Optional[float] = 1.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_prod used for the "previous" timestep when it falls before t=0
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample (DDPM applies no scaling, so `sample` is returned unchanged)
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        The timesteps are `num_inference_steps` evenly strided values in descending order, with stride
        `num_train_timesteps // num_inference_steps`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which `self.timesteps` is moved.

        Raises:
            ValueError: if `num_inference_steps` exceeds `num_train_timesteps`.
        """
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

    def _previous_timestep(self, timestep):
        """
        Return the timestep preceding `timestep` in the current schedule.

        Before `set_timesteps` is called, the full training schedule is assumed, so the stride is 1. The result may be
        negative for the last denoising step; callers treat a negative value as "before t=0".
        """
        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // num_inference_steps

    def _get_variance(self, t, predicted_variance=None, variance_type=None):
        """
        Compute the noise statistic used when sampling x_{t-1} from x_t.

        Depending on `variance_type` the returned value is a variance, a standard deviation or a log-variance:

        - `fixed_small`: posterior variance (formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf), clamped
          to at least 1e-20.
        - `fixed_small_log`: standard deviation, i.e. the square root of the clamped posterior variance.
        - `fixed_large`: the effective beta_t of the current (possibly strided) step.
        - `fixed_large_log`: log of that effective beta_t.
        - `learned`: `predicted_variance`, returned unchanged.
        - `learned_range`: a log-variance interpolated between the log posterior variance and log beta_t, where
          `predicted_variance` is expected to lie in [-1, 1].
        - any other value: the unclamped posterior variance.

        Args:
            t (`int`): current discrete timestep.
            predicted_variance (`torch.FloatTensor`, optional):
                variance channels predicted by the model; only used for `learned` and `learned_range`.
            variance_type (`str`, optional): overrides `self.config.variance_type` when given.
        """
        prev_t = self._previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # effective beta of the (possibly strided) step t -> prev_t
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small":
            variance = torch.clamp(variance, min=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = torch.log(torch.clamp(variance, min=1e-20))
            variance = torch.exp(0.5 * variance)
        elif variance_type == "fixed_large":
            variance = current_beta_t
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = torch.log(current_beta_t)
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Clamp before the log so a zero posterior variance cannot produce -inf.
            min_log = torch.log(torch.clamp(variance, min=1e-20))
            # Use the effective beta of this step (not self.betas[t]) so the upper bound is correct when inference
            # strides over several training steps; it equals self.betas[t] when the stride is 1.
            max_log = torch.log(current_beta_t)
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model. For `learned` and
                `learned_range` variance types it may carry twice the sample's channels, the second half being the
                predicted variance.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

        Returns:
            [`DDPMSchedulerOutput`] or `tuple`:
            [`DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first
            element is the sample tensor.

        Raises:
            ValueError: if `self.config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
        """
        t = timestep
        prev_t = self._previous_timestep(t)

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction`  for the DDPMScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (no noise is added on the final step t == 0)
        variance = 0
        if t > 0:
            device = model_output.device
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=device, dtype=model_output.dtype
            )
            noise_stat = self._get_variance(t, predicted_variance=predicted_variance)
            if self.variance_type == "fixed_small_log":
                # `_get_variance` already returned a standard deviation
                variance = noise_stat * variance_noise
            elif self.variance_type in ("learned_range", "fixed_large_log"):
                # `_get_variance` returned a log-variance; std = exp(0.5 * log var). Taking `** 0.5` of a (negative)
                # log-variance would yield NaN.
                variance = torch.exp(0.5 * noise_stat) * variance_noise
            else:
                variance = (noise_stat**0.5) * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

    def _sqrt_alpha_terms(self, timesteps: torch.IntTensor, reference: torch.FloatTensor):
        """
        Return `sqrt(alpha_bar_t)` and `sqrt(1 - alpha_bar_t)` for `timesteps`, on `reference`'s device and dtype and
        reshaped (trailing singleton dims) so they broadcast against `reference`.
        """
        # Cache the device move on self, but cast the dtype only locally: casting self.alphas_cumprod (e.g. to fp16)
        # would permanently lower the precision used by later `step` calls.
        self.alphas_cumprod = self.alphas_cumprod.to(device=reference.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=reference.dtype)
        timesteps = timesteps.to(reference.device)

        sqrt_alpha_prod = (alphas_cumprod[timesteps] ** 0.5).flatten()
        sqrt_one_minus_alpha_prod = ((1 - alphas_cumprod[timesteps]) ** 0.5).flatten()
        # both are 1-D after flatten, so they can be expanded together
        while len(sqrt_alpha_prod.shape) < len(reference.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        return sqrt_alpha_prod, sqrt_one_minus_alpha_prod

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Diffuse `original_samples` to `timesteps`: `sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples x_0.
            noise (`torch.FloatTensor`): noise to mix in, broadcastable to `original_samples`.
            timesteps (`torch.IntTensor`): one timestep per leading-dimension entry (flattened internally).

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._sqrt_alpha_terms(timesteps, original_samples)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Compute the v-prediction target `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample`.

        Args:
            sample (`torch.FloatTensor`): clean samples x_0.
            noise (`torch.FloatTensor`): noise, broadcastable to `sample`.
            timesteps (`torch.IntTensor`): one timestep per leading-dimension entry (flattened internally).

        Returns:
            `torch.FloatTensor`: the velocity target.
        """
        sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._sqrt_alpha_terms(timesteps, sample)
        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps