# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

17
import math
18
import warnings
19
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
20

21
import numpy as np
22
import torch
23

24
from ..configuration_utils import ConfigMixin, register_to_config
25
from .scheduling_utils import SchedulerMixin, SchedulerOutput
26
27
28
29


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Build a beta schedule by discretizing a squared-cosine `alpha_bar` curve, where `alpha_bar(t)` is the cumulative
    product of (1 - beta) up to time `t` in [0, 1].

    For each step `i`, the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, clipped from above by `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor with `num_diffusion_timesteps` entries.
    """

    def cumulative_alpha(fraction):
        # squared-cosine curve with a 0.008 offset
        return math.cos((fraction + 0.008) / 1.008 * math.pi / 2) ** 2

    total = num_diffusion_timesteps
    schedule = [
        min(1 - cumulative_alpha((step + 1) / total) / cumulative_alpha(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
Patrick von Platen's avatar
Patrick von Platen committed
55
56
57


class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`. Ignored when `trained_betas` is given.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` and
            `beta_schedule`.
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms steps; defaults to `False`.
        set_alpha_to_one (`bool`, default `False`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        skip_prk_steps: bool = False,
        set_alpha_to_one: bool = False,
        steps_offset: int = 0,
    ):
        # NOTE: this must be a single if/elif chain. With a separate `if` for the schedule, explicitly passed
        # `trained_betas` were always overwritten by the schedule selected through `beta_schedule`.
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # alpha product used as the "previous" value when the previous timestep would be negative
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.cur_model_output = 0  # weighted sum of model outputs accumulated over one Runge-Kutta cycle
        self.counter = 0  # number of `step_prk` / `step_plms` calls since the last `set_timesteps`
        self.cur_sample = None  # sample remembered at the start of a multi-call step
        self.ets = []  # history of model outputs consumed by the linear multi-step formulas

        # setable values
        self.num_inference_steps = None
        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.prk_timesteps = None
        self.plms_timesteps = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int, **kwargs) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Populates `self.prk_timesteps`, `self.plms_timesteps` and their concatenation `self.timesteps`, and resets the
        running state (`ets`, `counter`, `cur_model_output`).

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            **kwargs:
                only the deprecated `offset` key is read; when present it overrides `config.steps_offset` and a
                warning is emitted.
        """

        offset = self.config.steps_offset

        if "offset" in kwargs:
            warnings.warn(
                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
                " Please pass `steps_offset` to `__init__` instead."
            )

            offset = kwargs["offset"]

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
        self._timesteps += offset

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            self.prk_timesteps = np.array([])
            # the second-to-last ascending timestep is duplicated on purpose: `step_plms` is called twice for it
            # (counter 0 and counter 1)
            self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
                ::-1
            ].copy()
        else:
            # each of the last `pndm_order` timesteps is paired with its half-step neighbour
            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
                np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
            )
            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
            self.plms_timesteps = self._timesteps[:-3][
                ::-1
            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy

        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)

        # reset all running state so that a previous (possibly interrupted) run cannot leak into the new one
        self.ets = []
        self.counter = 0
        self.cur_model_output = 0

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        """
        if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
        else:
            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)

    def step_prk(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # even calls move half an inference step back, odd calls stay on the passed timestep
        diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
        prev_timestep = timestep - diff_to_prev
        # every call of a 4-call cycle starts from the timestep at which the cycle began
        timestep = self.prk_timesteps[self.counter // 4 * 4]

        # accumulate the model outputs with weights 1/6, 1/3, 1/3, 1/6 over the four calls of a cycle
        if self.counter % 4 == 0:
            self.cur_model_output += 1 / 6 * model_output
            self.ets.append(model_output)
            self.cur_sample = sample
        elif (self.counter - 1) % 4 == 0:
            self.cur_model_output += 1 / 3 * model_output
        elif (self.counter - 2) % 4 == 0:
            self.cur_model_output += 1 / 3 * model_output
        elif (self.counter - 3) % 4 == 0:
            # last call of the cycle: use the full weighted sum and reset the accumulator
            model_output = self.cur_model_output + 1 / 6 * model_output
            self.cur_model_output = 0

        # cur_sample should not be `None`
        cur_sample = self.cur_sample if self.cur_sample is not None else sample

        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_plms(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet, or if the Runge-Kutta warm-up is enabled
                (`skip_prk_steps=False`) and fewer than 3 model outputs have been recorded in `self.ets`.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(self.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )

        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        if self.counter != 1:
            # Only the four most recent outputs are ever read below, so drop older ones before appending.
            # Without this the history keeps every model output of the run alive.
            self.ets = self.ets[-3:]
            self.ets.append(model_output)
        else:
            # second call (counter == 1): redo the first step, so shift the timestep pair one step up
            prev_timestep = timestep
            timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

        if len(self.ets) == 1 and self.counter == 0:
            # first call: use the raw model output and remember the starting sample for the second call
            self.cur_sample = sample
        elif len(self.ets) == 1 and self.counter == 1:
            # second call: average the new output with the first one and restart from the remembered sample
            model_output = (model_output + self.ets[-1]) / 2
            sample = self.cur_sample
            self.cur_sample = None
        elif len(self.ets) == 2:
            # linear multi-step combinations (Adams-Bashforth coefficients) of increasing order
            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
        """
        Compute x_(t−δ) from x_t and the (combined) model output, following formula (9) of the PNDM paper.

        Args:
            sample: x_t, the sample at `timestep`.
            timestep: index into `self.alphas_cumprod` for the current step.
            prev_timestep: index for the target step; a negative value selects `self.final_alpha_cumprod`.
            model_output: e_θ(x_t, t), or a weighted combination of several model outputs.

        Returns:
            The sample at `prev_timestep`.
        """
        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
        # this function computes x_(t−δ) using the formula of (9)
        # Note that x_t needs to be added to both sides of the equation

        # Notation (<variable name> -> <name in paper>)
        # alpha_prod_t -> α_t
        # alpha_prod_t_prev -> α_(t−δ)
        # beta_prod_t -> (1 - α_t)
        # beta_prod_t_prev -> (1 - α_(t−δ))
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # corresponds to (α_(t−δ) - α_t) divided by
        # denominator of x_t in formula (9) and plus 1
        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
        # sqrt(α_(t−δ)) / sqrt(α_t)
        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

        # corresponds to denominator of e_θ(x_t, t) in formula (9)
        # (`**` binds tighter than `*`: the first term is α_t * sqrt(1 - α_(t−δ)))
        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
        ) ** (0.5)

        # full formula (9)
        prev_sample = (
            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
        )

        return prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Mix `noise` into `original_samples` according to the cumulative alpha product at `timesteps`:
        `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`.

        As a side effect, `self.alphas_cumprod` is moved to the device of `original_samples` if it lives elsewhere.

        Args:
            original_samples (`torch.FloatTensor`): the samples to add noise to.
            noise (`torch.FloatTensor`): the noise to add.
            timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`, one per sample.

        Returns:
            `torch.Tensor`: the noisy samples.
        """
        if self.alphas_cumprod.device != original_samples.device:
            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)

        if timesteps.device != original_samples.device:
            timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # append trailing singleton dims so the per-sample coefficient broadcasts over the sample dims
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps