# scheduling_pndm_flax.py
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`jnp.ndarray`): float32 array of shape `(num_diffusion_timesteps,)` holding the betas used by the
            scheduler to step the model outputs; every entry is clipped to at most `max_beta`.
    """

    def alpha_bar(time_step):
        # Squared-cosine cumulative alpha, evaluated with an offset of 0.008 on the normalized time.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1); near t = 1 the ratio tends to 0, so beta is clipped.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return jnp.array(betas, dtype=jnp.float32)


@flax.struct.dataclass
class PNDMSchedulerState:
    """
    State carried between calls of `FlaxPNDMScheduler` methods.

    As a `flax.struct.dataclass`, instances are updated functionally via `.replace(...)`, which returns a new
    instance rather than mutating this one.
    """

    # setable values
    # Timestep grid: descending training timesteps from `create`, replaced by the inference grid in `set_timesteps`.
    _timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None
    _offset: int = 0
    prk_timesteps: Optional[jnp.ndarray] = None
    plms_timesteps: Optional[jnp.ndarray] = None
    timesteps: Optional[jnp.ndarray] = None

    # running values
    cur_model_output: Optional[jnp.ndarray] = None
    counter: int = 0
    cur_sample: Optional[jnp.ndarray] = None
    # Build a fresh empty array per instance: a jax array used directly as a class-level default would be shared,
    # and Python 3.11+ dataclasses may reject it as an unhashable default.
    ets: jnp.ndarray = flax.struct.field(default_factory=lambda: jnp.array([]))

    @classmethod
    def create(cls, num_train_timesteps: int):
        """Return a state whose `_timesteps` are `num_train_timesteps - 1, ..., 1, 0`."""
        return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1])


@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
    """
    Output of the `FlaxPNDMScheduler` step functions: the `SchedulerOutput` fields (e.g. `prev_sample`) plus the
    scheduler state to pass into the next step call.
    """

    # State returned by the step function (with its `counter` already advanced).
    state: PNDMSchedulerState


class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms steps; defaults to `False`.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        skip_prk_steps: bool = False,
    ):
        # `trained_betas`, when given, takes precedence over `beta_schedule` (hence a single if/elif chain).
        if trained_betas is not None:
            self.betas = jnp.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

    def set_timesteps(
        self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
    ) -> PNDMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`PNDMSchedulerState`):
                the `FlaxPNDMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            offset (`int`):
                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.

        Returns:
            `PNDMSchedulerState`: a new state holding the inference schedule (`prk_timesteps`, `plms_timesteps` and
            their concatenation `timesteps`, ordered from high to low), with the running values (`counter`, `ets`,
            `cur_model_output`, `cur_sample`) reset.
        """
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
        # The grid is kept ASCENDING ([0, step_ratio, 2 * step_ratio, ...]): the slicing below takes its tail as the
        # highest timesteps and then reverses, so the resulting schedules run from high to low timestep.
        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()
        _timesteps = _timesteps + offset

        state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
            # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
            # The second-highest timestep is duplicated: `step_plms` uses its second call (counter == 1) to redo the
            # first interval with an averaged model output.
            state = state.replace(
                prk_timesteps=jnp.array([]),
                plms_timesteps=jnp.concatenate(
                    [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]]
                )[::-1],
            )
        else:
            # The highest `pndm_order` grid points, each paired with its half-step (t, t + step_ratio // 2).
            prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
                jnp.array([0, step_ratio // 2]), self.pndm_order
            )

            state = state.replace(
                prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1],
                plms_timesteps=state._timesteps[:-3][::-1],
            )

        return state.replace(
            # int32 is requested directly: without `jax_enable_x64`, JAX truncates int64 requests to int32 anyway.
            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
            ets=jnp.array([]),
            counter=0,
            cur_model_output=None,
            cur_sample=None,
        )

    def step(
        self,
        state: PNDMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

        Args:
            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        """
        if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
            return self.step_prk(
                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
            )
        else:
            return self.step_plms(
                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
            )

    def step_prk(
        self,
        state: PNDMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.

        Args:
            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        Raises:
            `ValueError`: if `set_timesteps` has not been run on `state`.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
        prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
        # Every stage of one RK step evaluates the update from the timestep at the start of that 4-call group.
        timestep = state.prk_timesteps[state.counter // 4 * 4]

        # Running RK accumulator; `None` (fresh state) means nothing has been accumulated yet.
        accumulated = 0 if state.cur_model_output is None else state.cur_model_output
        rk_stage = state.counter % 4

        if rk_stage == 0:
            state = state.replace(
                cur_model_output=accumulated + 1 / 6 * model_output,
                ets=self._append_ets(state.ets, model_output),
                cur_sample=sample,
            )
        elif rk_stage in (1, 2):
            state = state.replace(cur_model_output=accumulated + 1 / 3 * model_output)
        else:
            # Last stage: combine the weighted stages (1/6, 1/3, 1/3, 1/6) and clear the accumulator.
            model_output = accumulated + 1 / 6 * model_output
            state = state.replace(cur_model_output=None)

        # cur_sample should not be `None`
        cur_sample = state.cur_sample if state.cur_sample is not None else sample

        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
        state = state.replace(counter=state.counter + 1)

        if not return_dict:
            return (prev_sample, state)

        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)

    def step_plms(
        self,
        state: PNDMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        Raises:
            `ValueError`: if `set_timesteps` has not been run, or if PRK steps are enabled but fewer than 3 model
            outputs have been recorded.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if not self.config.skip_prk_steps and len(state.ets) < 3:
            raise ValueError(
                f"{self.__class__} can only be run AFTER scheduler has been run "
                "in 'prk' mode for at least 12 iterations "
                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
                "for more information."
            )

        step_ratio = self.config.num_train_timesteps // state.num_inference_steps
        prev_timestep = max(timestep - step_ratio, 0)

        if state.counter != 1:
            state = state.replace(ets=self._append_ets(state.ets, model_output))
        else:
            # Second call: recompute the first interval (timestep + step_ratio -> timestep) instead of advancing.
            prev_timestep = timestep
            timestep = timestep + step_ratio

        if len(state.ets) == 1 and state.counter == 0:
            # First call: plain step; remember the starting sample for the corrected redo on the next call.
            state = state.replace(cur_sample=sample)
        elif len(state.ets) == 1 and state.counter == 1:
            model_output = (model_output + state.ets[-1]) / 2
            sample = state.cur_sample
            state = state.replace(cur_sample=None)
        elif len(state.ets) == 2:
            model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
        elif len(state.ets) == 3:
            model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
        else:
            model_output = (1 / 24) * (
                55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
            )

        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
        state = state.replace(counter=state.counter + 1)

        if not return_dict:
            return (prev_sample, state)

        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)

    @staticmethod
    def _append_ets(ets: jnp.ndarray, model_output: jnp.ndarray) -> jnp.ndarray:
        """
        Return `ets` with `model_output` stacked on as the newest entry along axis 0.

        Only the 4 most recent entries are kept: `step_plms` reads at most `ets[-4:]` and only compares `len(ets)`
        against values below 4, so older entries are never used. Bounding the history also avoids re-copying an
        ever-growing array on every step.
        """
        newest = jnp.expand_dims(model_output, axis=0)
        if len(ets) == 0:
            # The initial `ets` is an empty 1-D placeholder whose shape cannot be concatenated with `newest`.
            return newest
        return jnp.concatenate([ets[-3:], newest], axis=0)

    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
        # this function computes x_(t−δ) using the formula of (9)
        # Note that x_t needs to be added to both sides of the equation

        # Notation (<variable name> -> <name in paper>
        # alpha_prod_t -> α_t
        # alpha_prod_t_prev -> α_(t−δ)
        # beta_prod_t -> (1 - α_t)
        # beta_prod_t_prev -> (1 - α_(t−δ))
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)
        # NOTE(review): with offset 0 the index is timestep + 1, which can reach len(alphas_cumprod) when
        # num_inference_steps == num_train_timesteps; JAX clamps out-of-range gathers instead of raising.
        alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # corresponds to (α_(t−δ) - α_t) divided by
        # denominator of x_t in formula (9) and plus 1
        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
        # sqrt(α_(t−δ)) / sqrt(α_t)
        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

        # corresponds to denominator of e_θ(x_t, t) in formula (9)
        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
        ) ** (0.5)

        # full formula (9)
        prev_sample = (
            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
        )

        return prev_sample

    def add_noise(
        self,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Return sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise.

        The per-timestep coefficients are flattened and given trailing singleton axes so they broadcast against
        `original_samples` along its leading axis. This is done with plain `jnp` reshapes so the method does not
        depend on a framework-generic shape helper.
        """
        alpha_prod = self.alphas_cumprod[timesteps]
        sqrt_alpha_prod = self._expand_to_ndim(alpha_prod**0.5, original_samples)
        sqrt_one_minus_alpha_prod = self._expand_to_ndim((1 - alpha_prod) ** 0.5, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    @staticmethod
    def _expand_to_ndim(values: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
        """Flatten `values` and append singleton axes until it has as many dimensions as `target`."""
        values = values.flatten()
        return values.reshape(values.shape + (1,) * (target.ndim - values.ndim))

    def __len__(self):
        return self.config.num_train_timesteps