# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.

    Returns:
        betas (`torch.FloatTensor`): 1-D `torch.float32` tensor of length `num_diffusion_timesteps` holding the betas
        used by the scheduler to step the model outputs. Every entry is clipped to at most `max_beta`.
    """

    def alpha_bar(time_step):
        # squared-cosine cumulative alpha product; the 0.008 offset keeps the betas near t=0 from being vanishingly
        # small, and dividing by 1.008 maps t=1 exactly onto cos(pi/2) = 0
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1); near t=1 alpha_bar -> 0 so this -> 1, hence the clip
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            The values are converted to a `torch.float32` tensor.
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms steps; defaults to `False`.
        set_alpha_to_one (`bool`, default `False`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        skip_prk_steps: bool = False,
        set_alpha_to_one: bool = False,
        steps_offset: int = 0,
        **kwargs,
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

        if trained_betas is not None:
            # `torch.tensor` (unlike `torch.from_numpy`) also accepts the plain list a JSON config round-trip yields,
            # and pinning float32 keeps `betas` in the same dtype as the built-in schedules below
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # alpha product used when a step's previous timestep falls below 0 (see `_get_prev_sample`)
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.cur_model_output = 0  # running weighted sum of the model outputs of the current RK step
        self.counter = 0  # number of `step*` calls since the last `set_timesteps`
        self.cur_sample = None  # sample at the start of the current RK / first PLMS step
        self.ets = []  # history of model outputs consumed by the linear multi-step formula

        # setable values
        self.num_inference_steps = None
        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.prk_timesteps = None
        self.plms_timesteps = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Also resets the scheduler's running state (`ets`, `counter`, `cur_model_output`, `cur_sample`) so that a new
        sampling loop is not affected by a previous, possibly interrupted, one.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device the resulting `self.timesteps` tensor is moved to.
        """
        deprecated_offset = deprecate(
            "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
        )
        offset = deprecated_offset or self.config.steps_offset

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
        self._timesteps += offset

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            self.prk_timesteps = np.array([])
            # the second-to-last timestep is duplicated: the first PLMS step is evaluated twice (see `step_plms`)
            self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
                ::-1
            ].copy()
        else:
            # pair each of the last `pndm_order` timesteps with its half-step midpoint
            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
                np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
            )
            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
            self.plms_timesteps = self._timesteps[:-3][
                ::-1
            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy

        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

        # reset all running state, including the partial RK accumulator and the cached sample
        self.ets = []
        self.counter = 0
        self.cur_model_output = 0
        self.cur_sample = None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        # checked here as well: `len(self.prk_timesteps)` below would otherwise fail with an opaque TypeError
        # before `step_prk` / `step_plms` get a chance to raise this message
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
        else:
            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)

    def step_prk(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # even counters step a half interval, odd counters re-evaluate at the same point
        diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
        prev_timestep = timestep - diff_to_prev
        # all four RK evaluations are propagated from the timestep at the start of the current RK step
        timestep = self.prk_timesteps[self.counter // 4 * 4]

        # RK4 weights 1/6, 1/3, 1/3, 1/6 accumulated over the four evaluations
        if self.counter % 4 == 0:
            self.cur_model_output += 1 / 6 * model_output
            self.ets.append(model_output)
            self.cur_sample = sample
        elif (self.counter - 1) % 4 == 0:
            self.cur_model_output += 1 / 3 * model_output
        elif (self.counter - 2) % 4 == 0:
            self.cur_model_output += 1 / 3 * model_output
        elif (self.counter - 3) % 4 == 0:
            model_output = self.cur_model_output + 1 / 6 * model_output
            self.cur_model_output = 0

        # cur_sample should not be `None`
        cur_sample = self.cur_sample if self.cur_sample is not None else sample

        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_plms(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet, or if PRK steps are enabled and fewer than three
                model outputs have been recorded.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(self.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )

        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # the second call (counter == 1) re-evaluates the first timestep, so it is not recorded in the history and
        # steps from `timestep + ratio` back to `timestep` again
        if self.counter != 1:
            self.ets.append(model_output)
        else:
            prev_timestep = timestep
            timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

        # Adams-Bashforth style combinations of increasing order as more history becomes available
        if len(self.ets) == 1 and self.counter == 0:
            model_output = model_output
            self.cur_sample = sample
        elif len(self.ets) == 1 and self.counter == 1:
            model_output = (model_output + self.ets[-1]) / 2
            sample = self.cur_sample
            self.cur_sample = None
        elif len(self.ets) == 2:
            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
        # this function computes x_(t−δ) using the formula of (9)
        # Note that x_t needs to be added to both sides of the equation

        # Notation (<variable name> -> <name in paper>)
        # alpha_prod_t -> α_t
        # alpha_prod_t_prev -> α_(t−δ)
        # beta_prod_t -> (1 - α_t)
        # beta_prod_t_prev -> (1 - α_(t−δ))
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)
        alpha_prod_t = self.alphas_cumprod[timestep]
        # a negative previous timestep means the chain has stepped past t=0: use the configured final alpha product
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # corresponds to (α_(t−δ) - α_t) divided by
        # denominator of x_t in formula (9) and plus 1
        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
        # sqrt(α_(t−δ)) / sqrt(α_t)
        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

        # corresponds to denominator of e_θ(x_t, t) in formula (9)
        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
        ) ** (0.5)

        # full formula (9)
        prev_sample = (
            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
        )

        return prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Diffuse `original_samples` forward to `timesteps`.

        Computes `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`, with the
        per-timestep coefficients broadcast over the trailing dimensions of `original_samples`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples.
            noise (`torch.FloatTensor`): noise to mix in; must broadcast against `original_samples`.
            timesteps (`torch.IntTensor`): timestep index (or indices) into `alphas_cumprod`.

        Returns:
            `torch.Tensor`: the noisy samples. For floating-point inputs the result keeps `original_samples`' dtype.
        """
        # cache the schedule on the samples' device so repeated calls don't transfer it again
        if self.alphas_cumprod.device != original_samples.device:
            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)

        if timesteps.device != original_samples.device:
            timesteps = timesteps.to(original_samples.device)

        # match the samples' floating dtype (e.g. fp16) locally; otherwise the float32 coefficients would silently
        # upcast the returned noisy samples to float32
        alphas_cumprod = self.alphas_cumprod
        if original_samples.is_floating_point():
            alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        # the scheduler's length is the number of training timesteps, not the number of inference steps
        return self.config.num_train_timesteps