# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

17
import math
18
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
19

20
import numpy as np
21
import torch
22

23
from ..configuration_utils import ConfigMixin, register_to_config
24
from .scheduling_utils import SchedulerMixin, SchedulerOutput
25
26
27
28


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    The alpha_bar function used here is the squared-cosine curve
    `alpha_bar(t) = cos(((t + 0.008) / 1.008) * pi / 2) ** 2`, which maps a normalized time t to the cumulative
    product of (1-beta) up to that part of the diffusion process. Each beta is derived from the ratio of
    consecutive alpha_bar values: `beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; every beta is clipped to this value. Use values lower than 1 to
            prevent singularities.

    Returns:
        betas (`np.ndarray`): float32 array of length `num_diffusion_timesteps` with the betas used by the scheduler
        to step the model outputs.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = [
        # clip: alpha_bar(1.0) is ~0, so the last ratio would otherwise push beta to ~1
        min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return np.array(betas, dtype=np.float32)


class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`. Ignored when `trained_betas` is given.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms steps; defaults to `False`.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        tensor_format: str = "pt",
        skip_prk_steps: bool = False,
    ):
        if trained_betas is not None:
            # Explicit betas take precedence. This must be the head of the `elif` chain, otherwise the default
            # "linear" schedule would silently overwrite them.
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        self.one = np.array(1.0)

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.cur_model_output = 0  # weighted sum of the Runge-Kutta model outputs of the current 4-step group
        self.counter = 0  # number of `step_prk` / `step_plms` calls since the last `set_timesteps`
        self.cur_sample = None  # sample held across the sub-steps that share one starting point
        self.ets = []  # history of model outputs consumed by the linear multi-step formulas

        # setable values
        self.num_inference_steps = None
        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self._offset = 0
        self.prk_timesteps = None
        self.plms_timesteps = None
        self.timesteps = None

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Populates `prk_timesteps`, `plms_timesteps` and `timesteps` (the concatenation of both, in descending order),
        and resets the running state (`ets`, `counter`).

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            offset (`int`):
                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
        """
        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        self._offset = offset
        # evenly spaced integer timesteps 0, step_ratio, 2 * step_ratio, ... shifted up by `offset`
        self._timesteps = np.arange(0, num_inference_steps) * step_ratio + self._offset

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            self.prk_timesteps = np.array([])
            # the second-to-last timestep is duplicated so that the second PLMS call (counter == 1) can re-evaluate it
            self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
                ::-1
            ].copy()
        else:
            # the last `pndm_order` timesteps, each followed by its half-step midpoint
            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
                np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
            )
            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
            self.plms_timesteps = self._timesteps[:-3][
                ::-1
            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy

        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)

        self.ets = []
        self.counter = 0
        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        # Fail early with the same message as `step_prk` / `step_plms`; otherwise `len(self.prk_timesteps)` below
        # would raise an opaque `TypeError` on `None`.
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
        return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)

    def step_prk(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # even sub-steps go half a step back, odd sub-steps stay at the current timestep
        diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
        prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
        # every sub-step of a 4-step group starts from the group's first timestep
        timestep = self.prk_timesteps[self.counter // 4 * 4]

        # Accumulate the RK4 weights 1/6, 1/3, 1/3, 1/6 over the 4 sub-steps of a group.
        phase = self.counter % 4
        if phase == 0:
            self.cur_model_output += 1 / 6 * model_output
            self.ets.append(model_output)
            self.cur_sample = sample
        elif phase in (1, 2):
            self.cur_model_output += 1 / 3 * model_output
        else:
            model_output = self.cur_model_output + 1 / 6 * model_output
            self.cur_model_output = 0

        # cur_sample should not be `None`
        cur_sample = self.cur_sample if self.cur_sample is not None else sample

        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_plms(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called yet, or if the PRK warm-up has not produced enough
                history while `skip_prk_steps` is False.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(self.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )

        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        prev_timestep = max(timestep - step_ratio, 0)

        if self.counter != 1:
            # The formulas below read at most the last 4 entries; keep only 3 before appending so the
            # history does not retain every model output of the whole run.
            self.ets = self.ets[-3:]
            self.ets.append(model_output)
        else:
            # second call in skip-prk mode: re-evaluate the previous interval instead of advancing
            prev_timestep = timestep
            timestep = timestep + step_ratio

        if len(self.ets) == 1 and self.counter == 0:
            self.cur_sample = sample
        elif len(self.ets) == 1 and self.counter == 1:
            model_output = (model_output + self.ets[-1]) / 2
            sample = self.cur_sample
            self.cur_sample = None
        elif len(self.ets) == 2:
            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
        """
        Compute x_(t−δ) from x_t and the (combined) model output using formula (9) of the PNDM paper
        (https://arxiv.org/pdf/2202.09778.pdf). Note that x_t needs to be added to both sides of the equation.

        `alphas_cumprod` is indexed at `timestep + 1 - self._offset` (and likewise for `timestep_prev`), as in the
        original implementation.
        """
        # Notation (<variable name> -> <name in paper>)
        # alpha_prod_t -> α_t
        # alpha_prod_t_prev -> α_(t−δ)
        # beta_prod_t -> (1 - α_t)
        # beta_prod_t_prev -> (1 - α_(t−δ))
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)
        alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # Coefficient of x_t: (α_(t−δ) - α_t) divided by the denominator of x_t in formula (9), plus 1.
        # Identity: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 = sqrt(α_(t−δ)) / sqrt(α_t)
        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

        # corresponds to denominator of e_θ(x_t, t) in formula (9)
        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
        ) ** (0.5)

        # full formula (9)
        prev_sample = (
            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
        )

        return prev_sample

    def add_noise(
        self,
        original_samples: Union[torch.FloatTensor, np.ndarray],
        noise: Union[torch.FloatTensor, np.ndarray],
        timesteps: Union[torch.IntTensor, np.ndarray],
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Noise `original_samples` to the given `timesteps`:
        `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`.

        Args:
            original_samples (`torch.FloatTensor` or `np.ndarray`): clean samples.
            noise (`torch.FloatTensor` or `np.ndarray`): noise to mix in.
            timesteps (`torch.IntTensor` or `np.ndarray`): indices into `alphas_cumprod`, one per sample.

        Returns:
            the noisy samples (same array/tensor family as the inputs).
        """
        # mps requires indices to be in the same device, so we use cpu as is the default with cuda.
        # Only torch tensors have a device; numpy timesteps / numpy-format schedulers are indexed as-is.
        if isinstance(timesteps, torch.Tensor) and isinstance(self.alphas_cumprod, torch.Tensor):
            timesteps = timesteps.to(self.alphas_cumprod.device)

        # `match_shape` comes from SchedulerMixin; presumably it broadcasts the per-sample coefficients
        # to the sample shape — NOTE(review): confirm against scheduling_utils.
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Number of training timesteps (`config.num_train_timesteps`)."""
        return self.config.num_train_timesteps