scheduling_ddim.py
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    The schedule is defined via a function alpha_bar that takes an argument t in [0, 1] and returns the
    cumulative product of (1-beta) up to that point of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

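    # cosine alpha_bar with offset s = 0.008, following the cosine schedule of
    # "Improved Denoising Diffusion Probabilistic Models" (https://arxiv.org/abs/2102.09672)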
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models (DDIM) is a scheduler that extends the denoising procedure introduced in
    denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the alpha cumulative product at that step and at the previous one. For the
            final step there is no previous alpha. When this option is `True` the previous alpha product is fixed
            to `1`; otherwise it uses the alpha value at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
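
    Example (a minimal denoising-loop sketch; `unet` and the initial noisy `sample` are hypothetical
    placeholders, not provided by this class):

        scheduler = DDIMScheduler()
        scheduler.set_timesteps(50)
        for t in scheduler.timesteps:
            noise_pred = unet(sample, t)  # predicted noise e_theta(x_t, t)
            sample = scheduler.step(noise_pred, t, sample).prev_sample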

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        tensor_format: str = "pt",
    ):
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the alpha cumulative product at step 0 instead.
        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # settable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def _get_variance(self, timestep, prev_timestep):
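        # variance of formula (16) in https://arxiv.org/pdf/2010.02502.pdf for eta = 1;
        # `step` scales its square root by eta to obtain sigma_t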
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps: int, **kwargs):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
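
        Example (illustrative, assuming the default `num_train_timesteps=1000`):

            scheduler.set_timesteps(50)
            # step_ratio = 1000 // 50 = 20, so scheduler.timesteps becomes
            # [980, 960, ..., 20, 0], shifted by `steps_offset`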
        """

        offset = self.config.steps_offset

        if "offset" in kwargs:
            warnings.warn(
                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
                " Please pass `steps_offset` to `__init__` instead.",
                DeprecationWarning,
            )

            offset = kwargs["offset"]

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to int avoids issues when num_inference_steps is a power of 3
        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
        self.timesteps += offset
        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of the added noise in the diffusion step (`eta=0` gives deterministic DDIM).
            use_clipped_model_output (`bool`): if `True`, re-derive `model_output` from `sample` and the
                clipped `pred_original_sample`, as is done for GLIDE.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
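
        Example (single-step sketch; `noise_pred`, `t`, and `sample` are hypothetical values from a
        denoising loop):

            output = scheduler.step(noise_pred, t, sample, eta=0.0)
            sample = output.prev_sample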

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand what follows

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, generator=generator).to(device)
            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

            if not torch.is_tensor(model_output):
                variance = variance.numpy()

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: Union[torch.FloatTensor, np.ndarray],
        noise: Union[torch.FloatTensor, np.ndarray],
        timesteps: Union[torch.IntTensor, np.ndarray],
    ) -> Union[torch.FloatTensor, np.ndarray]:
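        """
        Diffuse `original_samples` to the given `timesteps` with the closed-form forward process
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
        """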
        if self.tensor_format == "pt":
            timesteps = timesteps.to(self.alphas_cumprod.device)
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps