# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

17
import math
18
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
19

20
import numpy as np
21
import torch
22

23
from ..configuration_utils import ConfigMixin, register_to_config
24
from .scheduling_utils import SchedulerMixin, SchedulerOutput
25
26
27
28


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs (dtype `float32`, one entry
        per timestep).
    """

    def alpha_bar(time_step):
        # squared-cosine cumulative-product curve evaluated at a normalized time in [0, 1]
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        # consecutive normalized times delimiting step i
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), capped at max_beta
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)
Patrick von Platen's avatar
Patrick von Platen committed
54
55
56


class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            an explicit array of betas; when given it is used as-is and `beta_start`, `beta_end` and `beta_schedule`
            are ignored.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms steps; defaults to `False`.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        tensor_format: str = "pt",
        skip_prk_steps: bool = False,
    ):
        # Explicit betas take precedence; the schedule branches must be `elif` so they
        # cannot overwrite user-supplied betas.
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        self.one = np.array(1.0)

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.cur_model_output = 0  # weighted sum of model outputs accumulated by `step_prk`
        self.counter = 0  # number of `step_*` calls since the last `set_timesteps`
        self.cur_sample = None  # sample stored at the start of a multi-call update
        self.ets = []  # history of model outputs consumed by the multi-step formulas

        # setable values
        self.num_inference_steps = None
        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self._offset = 0
        self.prk_timesteps = None
        self.plms_timesteps = None
        self.timesteps = None

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            offset (`int`):
                constant added to every inference timestep; `_get_prev_sample` subtracts it again when indexing
                `alphas_cumprod`.

        Raises:
            ValueError: if `num_inference_steps` is not in `[1, num_train_timesteps]`.
        """
        num_train_timesteps = self.config.num_train_timesteps
        # Fail early with a clear message: otherwise the integer stride below is zero
        # (or the division itself fails) and the error is obscure.
        if not 1 <= num_inference_steps <= num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps` must be between 1 and {num_train_timesteps} (the number of train "
                f"timesteps), but got {num_inference_steps}."
            )
        step_ratio = num_train_timesteps // num_inference_steps

        self.num_inference_steps = num_inference_steps
        self._offset = offset
        self._timesteps = np.array([t + self._offset for t in range(0, num_train_timesteps, step_ratio)])

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            self.prk_timesteps = np.array([])
            # the second-to-last timestep appears twice: `step_plms` uses two calls for its first update
            self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
                ::-1
            ].copy()
        else:
            # each of the last `pndm_order` timesteps is paired with its half-step neighbour
            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
                np.array([0, step_ratio // 2]), self.pndm_order
            )
            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
            self.plms_timesteps = self._timesteps[:-3][
                ::-1
            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy

        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)

        # reset the running state so a new sampling loop starts clean
        self.ets = []
        self.counter = 0
        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            `SchedulerOutput` (or a one-element tuple when `return_dict` is `False`): updated sample in the
            diffusion chain.

        """
        if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
        else:
            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)

    def step_prk(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # only the even-numbered calls move by half an inference stride
        diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
        prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
        # every call within a group of four starts from the group's first timestep
        timestep = self.prk_timesteps[self.counter // 4 * 4]

        # position of this call inside the current group of four (weights 1/6, 1/3, 1/3, 1/6)
        phase = self.counter % 4
        if phase == 0:
            self.cur_model_output += 1 / 6 * model_output
            self.ets.append(model_output)
            self.cur_sample = sample
        elif phase == 1:
            self.cur_model_output += 1 / 3 * model_output
        elif phase == 2:
            self.cur_model_output += 1 / 3 * model_output
        elif phase == 3:
            # last call of the group: use the full weighted sum and reset the accumulator
            model_output = self.cur_model_output + 1 / 6 * model_output
            self.cur_model_output = 0

        # cur_sample should not be `None`
        cur_sample = self.cur_sample if self.cur_sample is not None else sample

        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def step_plms(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.

        Raises:
            ValueError: if `set_timesteps` has not been called yet, or if PRK steps are enabled but fewer than three
                model outputs have been collected.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(self.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )

        prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)

        if self.counter != 1:
            self.ets.append(model_output)
        else:
            # second call: redo the first update, so shift the timestep pair up by one stride
            prev_timestep = timestep
            timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

        if len(self.ets) == 1 and self.counter == 0:
            # first call: use the raw model output and remember the starting sample
            self.cur_sample = sample
        elif len(self.ets) == 1 and self.counter == 1:
            # second call: average with the first output and restart from the stored sample
            model_output = (model_output + self.ets[-1]) / 2
            sample = self.cur_sample
            self.cur_sample = None
        elif len(self.ets) == 2:
            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
        self.counter += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
        # this function computes x_(t−δ) using the formula of (9)
        # Note that x_t needs to be added to both sides of the equation

        # Notation (<variable name> -> <name in paper>)
        # alpha_prod_t -> α_t
        # alpha_prod_t_prev -> α_(t−δ)
        # beta_prod_t -> (1 - α_t)
        # beta_prod_t_prev -> (1 - α_(t−δ))
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)

        # NOTE(review): the index is `timestep + 1 - offset`; with offset 0 and a timestep equal to
        # `num_train_timesteps - 1` this points one past the last entry — confirm the intended bounds.
        alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # corresponds to (α_(t−δ) - α_t) divided by
        # denominator of x_t in formula (9) and plus 1
        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
        # sqrt(α_(t−δ)) / sqrt(α_t)
        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

        # corresponds to denominator of e_θ(x_t, t) in formula (9)
        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
        ) ** (0.5)

        # full formula (9)
        prev_sample = (
            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
        )

        return prev_sample

    def add_noise(
        self,
        original_samples: Union[torch.FloatTensor, np.ndarray],
        noise: Union[torch.FloatTensor, np.ndarray],
        timesteps: Union[torch.IntTensor, np.ndarray],
    ) -> Union[torch.FloatTensor, np.ndarray]:
        """
        Mix `noise` into `original_samples` using the cumulative alpha products at `timesteps`:
        `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`.

        Args:
            original_samples (`torch.FloatTensor` or `np.ndarray`): the samples to add noise to.
            noise (`torch.FloatTensor` or `np.ndarray`): the noise to mix in.
            timesteps (`torch.IntTensor` or `np.ndarray`): indices into `alphas_cumprod` for each sample.

        Returns:
            the noisy samples.
        """
        # mps requires indices to be in the same device, so we use cpu as is the default with cuda.
        # Only torch tensors can be moved between devices; numpy inputs are used as-is.
        if isinstance(timesteps, torch.Tensor) and isinstance(self.alphas_cumprod, torch.Tensor):
            timesteps = timesteps.to(self.alphas_cumprod.device)
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        # length of the scheduler is the number of training timesteps from the config
        return self.config.num_train_timesteps