scheduling_ddpm.py 15.5 KB
Newer Older
Ryan Russell's avatar
Ryan Russell committed
1
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
Patrick von Platen's avatar
improve  
Patrick von Platen committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
16

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

anton-l's avatar
anton-l committed
17
import math
18
from dataclasses import dataclass
19
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
20

Patrick von Platen's avatar
Patrick von Platen committed
21
import numpy as np
22
import torch
Patrick von Platen's avatar
improve  
Patrick von Platen committed
23

24
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
25
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from .scheduling_utils import SchedulerMixin


@dataclass
class DDPMSchedulerOutput(BaseOutput):
    """
    Container returned by [`DDPMScheduler.step`] when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample x_{t-1} computed for the previous timestep. Feed it back to the model as the next input of
            the denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The current estimate of the fully denoised sample x_{0}, derived from the model output at this
            timestep. Useful for previewing progress or for guidance; `None` when not provided.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
45
46
47
48


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    The local function `alpha_bar` maps a normalized time t in [0, 1] to the cumulative product of (1-beta) up to
    that point of the diffusion process. It is the squared-cosine schedule
    alpha_bar(t) = cos(((t + s) / (1 + s)) * pi / 2) ** 2 with offset s = 0.008 (https://arxiv.org/abs/2102.09672).
    Each beta_i is then 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities near t = 1,
            where alpha_bar approaches 0.

    Returns:
        betas (`torch.FloatTensor`): a 1-D float32 tensor of length `num_diffusion_timesteps` holding the betas used
        by the scheduler to step the model outputs.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = [
        min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
Patrick von Platen's avatar
improve  
Patrick von Platen committed
74
75


Patrick von Platen's avatar
Patrick von Platen committed
76
class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
    Langevin dynamics sampling.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
        trained_betas (`np.ndarray` or `list`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            The values are converted to a float32 tensor.
        variance_type (`str`):
            how the variance of the noise added in `step` is computed. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. For `learned` and
            `learned_range`, a model output with twice the sample's channels is split in two: the second half is the
            predicted variance (`learned`) or an interpolation value expected in [-1, 1] (`learned_range`).
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
            `v-prediction` is not supported for this scheduler.
    """

    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, list]] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        **kwargs,
    ):
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

        if trained_betas is not None:
            # `torch.tensor` (unlike `torch.from_numpy`) also accepts plain lists, e.g. betas read back from a JSON
            # config, and the explicit dtype keeps `betas` float32 like every other schedule below.
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. Values larger
                than `num_train_timesteps` are clamped to `num_train_timesteps`.
            device (`str` or `torch.device`, optional):
                the device the resulting `self.timesteps` tensor is moved to.

        Raises:
            ValueError: if `num_inference_steps` is smaller than 1.
        """
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be a positive integer, got {num_inference_steps}.")
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Build exactly `num_inference_steps` evenly spaced timesteps, from (n - 1) * step_ratio down to 0.
        # `np.arange(0, num_train_timesteps, step_ratio)` would produce ceil(T / step_ratio) entries instead
        # (e.g. 334 rather than 300 for T=1000), whenever T is not a multiple of `num_inference_steps`.
        timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

    def _previous_timestep(self, timestep):
        """
        Return the timestep that `step` moves to from `timestep` (negative for the last step of the chain).

        Before `set_timesteps` is called this is `timestep - 1`. Afterwards it is `timestep` minus the stride between
        consecutive inference timesteps.
        """
        if self.num_inference_steps is None:
            return timestep - 1
        return timestep - self.config.num_train_timesteps // self.num_inference_steps

    def _transition_coefficients(self, t):
        """
        Return `(alpha_prod_t, alpha_prod_t_prev, current_alpha_t, current_beta_t)` for the jump from `t` to
        `self._previous_timestep(t)`.

        `current_alpha_t` / `current_beta_t` are the effective alpha / beta of the (possibly multi-step) jump. For a
        single training step they are exactly `self.alphas[t]` / `self.betas[t]`.
        """
        prev_t = self._previous_timestep(t)
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        if prev_t == t - 1:
            # single training step: use the stored per-step values so results match the plain DDPM recursion exactly
            current_alpha_t = self.alphas[t]
            current_beta_t = self.betas[t]
        else:
            # strided jump: alpha_prod_t = alpha_prod_t_prev * prod(alphas over the skipped steps)
            current_alpha_t = alpha_prod_t / alpha_prod_t_prev
            current_beta_t = 1 - current_alpha_t
        return alpha_prod_t, alpha_prod_t_prev, current_alpha_t, current_beta_t

    def _get_variance(self, t, predicted_variance=None, variance_type=None):
        """
        Return the noise-scale term that `step` uses at timestep `t` (`variance_type` defaults to the config value).

        Depending on `variance_type`, the returned value is a variance (`fixed_small`, `fixed_large`, `learned`),
        a standard deviation (`fixed_small_log`) or a log-variance (`fixed_large_log`, `learned_range`). `step`
        converts each of these to a standard deviation accordingly.
        """
        alpha_prod_t, alpha_prod_t_prev, _, current_beta_t = self._transition_coefficients(t)

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small":
            variance = torch.clamp(variance, min=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = torch.log(torch.clamp(variance, min=1e-20))
            variance = torch.exp(0.5 * variance)
        elif variance_type == "fixed_large":
            variance = current_beta_t
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = torch.log(current_beta_t)
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Interpolate in log space between the posterior variance and beta_t
            # (formula (15) of https://arxiv.org/abs/2102.09672); the clamp avoids log(0).
            min_log = torch.log(torch.clamp(variance, min=1e-20))
            max_log = torch.log(current_beta_t)
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Before `set_timesteps` is called, the previous timestep is `timestep - 1`. Afterwards, it is the previous
        entry of the inference schedule, and the posterior coefficients are computed for that (strided) jump.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        """
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            new_config = dict(self.config)
            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
            self._internal_dict = FrozenDict(new_config)

        t = timestep

        # models with a learned variance output twice the sample's channels: [prediction, variance]
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas for the jump from t to the previous timestep
        alpha_prod_t, alpha_prod_t_prev, current_alpha_t, current_beta_t = self._transition_coefficients(t)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
                " for the DDPMScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            if device.type == "mps":
                # randn does not work reproducibly on mps
                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
                variance_noise = variance_noise.to(device)
            else:
                variance_noise = torch.randn(
                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
                )

            noise_term = self._get_variance(t, predicted_variance=predicted_variance)
            if self.variance_type == "fixed_small_log":
                # `_get_variance` already returns a standard deviation for this type
                std = noise_term
            elif self.variance_type in ("fixed_large_log", "learned_range"):
                # `_get_variance` returns a log-variance for these types; `** 0.5` of a negative log would be NaN
                std = torch.exp(0.5 * noise_term)
            else:
                std = noise_term ** 0.5
            variance = std * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Diffuse `original_samples` forward to `timesteps` using the closed form of q(x_t | x_0)
        (formula (4) of https://arxiv.org/pdf/2006.11239.pdf):
        sqrt(alpha_prod_t) * original_samples + sqrt(1 - alpha_prod_t) * noise.

        Args:
            original_samples (`torch.FloatTensor`): clean samples x_0.
            noise (`torch.FloatTensor`): noise to mix in, broadcastable against `original_samples`.
            timesteps (`torch.IntTensor`): one timestep per sample. The per-timestep coefficients are flattened and
                padded with trailing singleton dimensions to broadcast over the remaining sample dimensions.

        Returns:
            `torch.FloatTensor`: the noisy samples x_t.
        """
        # Cast a local copy to the samples' device and dtype instead of reassigning `self.alphas_cumprod`. Reassigning
        # would permanently downcast it (e.g. to fp16) and move it off its device for every later `step` call.
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps