# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

anton-l's avatar
anton-l committed
17
import math
18
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
Patrick von Platen committed
20
import numpy as np
21
import torch
Patrick von Platen's avatar
improve  
Patrick von Platen committed
22

23
from ..configuration_utils import ConfigMixin, register_to_config
24
from .scheduling_utils import SchedulerMixin, SchedulerOutput
25
26
27
28


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Build a beta schedule by discretizing a squared-cosine `alpha_bar` curve.

    `alpha_bar(t)` gives the cumulative product of (1 - beta) up to normalized time t in [0, 1]. Each beta is recovered
    from two neighbouring points of that curve as `1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)`, with
    `t_i = i / num_diffusion_timesteps`, and is then capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; use values lower than 1 to prevent singularities.

    Returns:
        betas (`np.ndarray`): float32 array of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs.
    """

    def alpha_bar(progress):
        # squared cosine of the (offset, rescaled) normalized time
        return math.cos((progress + 0.008) / 1.008 * math.pi / 2) ** 2

    steps = num_diffusion_timesteps
    betas = [min(1 - alpha_bar((idx + 1) / steps) / alpha_bar(idx / steps), max_beta) for idx in range(steps)]
    return np.array(betas, dtype=np.float32)


class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
    Langevin dynamics sampling.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the `beta` value at the first timestep of the `linear` / `scaled_linear` schedules.
        beta_end (`float`): the `beta` value at the last timestep of the `linear` / `scaled_linear` schedules.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        variance_type (`str`):
            how the variance of the noise added in each reverse step is obtained. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. The `*_log` variants
            and `learned_range` are computed in log space.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        tensor_format: str = "pt",
    ):
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.one = np.array(1.0)

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

        self.variance_type = variance_type

    def set_timesteps(self, num_inference_steps: int):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Exactly `num_inference_steps` timesteps are produced, in descending order, spaced by
        `num_train_timesteps // num_inference_steps`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. Values larger
                than `num_train_timesteps` are clamped to it.

        Raises:
            ValueError: if `num_inference_steps` is smaller than 1.
        """
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be at least 1, got {num_inference_steps}.")
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # Scaling an index range (rather than `np.arange(0, T, step_ratio)`) yields exactly `num_inference_steps`
        # entries even when T is not divisible by it; for divisible T both forms are identical.
        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def _step_ratio(self):
        """Distance between consecutive timesteps of the current schedule (1 until `set_timesteps` is called)."""
        if self.num_inference_steps is None:
            return 1
        return self.config.num_train_timesteps // self.num_inference_steps

    def _get_alphas(self, t):
        """
        Return `(alpha_prod_t, alpha_prod_t_prev, current_alpha_t, current_beta_t)` for the transition from `t` to the
        previous timestep of the schedule (`t - step_ratio`; a negative previous timestep means the chain ends and
        `alpha_prod_t_prev` is 1).
        """
        step_ratio = self._step_ratio()
        prev_t = t - step_ratio
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        if step_ratio == 1:
            # consecutive steps: use the stored per-step values directly
            current_alpha_t = self.alphas[t]
            current_beta_t = self.betas[t]
        else:
            # skipping steps: the effective alpha of the jump is the ratio of the cumulative products
            current_alpha_t = alpha_prod_t / alpha_prod_t_prev
            current_beta_t = 1 - current_alpha_t
        return alpha_prod_t, alpha_prod_t_prev, current_alpha_t, current_beta_t

    @staticmethod
    def _exp(tensor):
        """Element-wise exponential for either a torch tensor or a numpy array/scalar."""
        if isinstance(tensor, torch.Tensor):
            return torch.exp(tensor)
        return np.exp(tensor)

    def _get_variance(self, t, predicted_variance=None, variance_type=None):
        """
        Variance of the reverse transition from timestep `t`.

        For `fixed_small_log`, `fixed_large_log` and `learned_range` the returned value is a log-variance; `learned`
        returns `predicted_variance` unchanged; the other types return a plain variance.
        """
        alpha_prod_t, alpha_prod_t_prev, _, current_beta_t = self._get_alphas(t)

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        if variance_type is None:
            variance_type = self.config.variance_type

        if variance_type == "fixed_small":
            variance = self.clip(variance, min_value=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = self.log(self.clip(variance, min_value=1e-20))
        elif variance_type == "fixed_large":
            variance = current_beta_t
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = self.log(current_beta_t)
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # interpolate between the two bounds in log space; `predicted_variance` in [-1, 1] maps to frac in [0, 1]
            min_log = self.log(self.clip(variance, min_value=1e-20))
            max_log = self.log(current_beta_t)
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        predict_epsilon=True,
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        After `set_timesteps` has been called with fewer steps than `num_train_timesteps`, the transition goes from
        `timestep` to `timestep - num_train_timesteps // num_inference_steps` and the posterior is computed for that
        jump.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. For the
                `learned` / `learned_range` variance types it may have twice as many entries along dim 1 as `sample`;
                the first half is then the prediction and the second half the predicted variance.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            predict_epsilon (`bool`):
                `True` if the model predicts the noise (epsilon); `False` if it predicts the sample directly.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        t = timestep

        half = sample.shape[1]
        if model_output.shape[1] == half * 2 and self.variance_type in ["learned", "learned_range"]:
            # prediction and predicted variance are stacked along dim 1; slicing works for torch and numpy alike
            model_output, predicted_variance = model_output[:, :half], model_output[:, half:]
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t, alpha_prod_t_prev, current_alpha_t, current_beta_t = self._get_alphas(t)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = model_output

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        variance = 0
        if t > 0:
            noise = self.randn_like(model_output, generator=generator)
            raw_variance = self._get_variance(t, predicted_variance=predicted_variance)
            if self.config.variance_type in ("fixed_small_log", "fixed_large_log", "learned_range"):
                # `_get_variance` returned a log-variance: std = exp(log_var / 2); `** 0.5` of it would be NaN
                std = self._exp(0.5 * raw_variance)
            else:
                std = raw_variance**0.5
            variance = std * noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return SchedulerOutput(prev_sample=pred_prev_sample)

    def add_noise(
        self,
        original_samples: Union[torch.FloatTensor, np.ndarray],
        noise: Union[torch.FloatTensor, np.ndarray],
        timesteps: Union[torch.IntTensor, np.ndarray],
    ) -> Union[torch.FloatTensor, np.ndarray]:
        """
        Diffuse `original_samples` forward to `timesteps` in closed form:
        `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps