scheduling_ddim_flax.py 12.3 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
25
from .scheduling_utils_flax import (
26
    CommonSchedulerState,
Kashif Rasul's avatar
Kashif Rasul committed
27
    FlaxKarrasDiffusionSchedulers,
28
29
    FlaxSchedulerMixin,
    FlaxSchedulerOutput,
30
    add_noise_common,
31
    get_velocity_common,
32
)
33
34
35
36


@flax.struct.dataclass
class DDIMSchedulerState:
    """
    State container for `FlaxDDIMScheduler`.

    Instances are never mutated in place; the scheduler derives updated copies with `.replace(...)` (for example in
    `set_timesteps`).
    """

    # Shared schedule quantities; the scheduler reads `common.alphas_cumprod` from it.
    common: CommonSchedulerState
    # Alpha product used as the "previous" value once the previous timestep would be negative (final step).
    final_alpha_cumprod: jnp.ndarray

    # setable values
    # Standard deviation of the initial noise distribution.
    init_noise_sigma: jnp.ndarray
    # Discrete timesteps, stored in descending order by the scheduler.
    timesteps: jnp.ndarray
    # Stays `None` until `set_timesteps` is called; `step` refuses to run while it is `None`.
    num_inference_steps: Optional[int] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        final_alpha_cumprod: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """
        Alternate constructor that leaves `num_inference_steps` at its default of `None`.

        Args:
            common (`CommonSchedulerState`): shared schedule quantities.
            final_alpha_cumprod (`jnp.ndarray`): alpha product used for the last denoising step.
            init_noise_sigma (`jnp.ndarray`): standard deviation of the initial noise.
            timesteps (`jnp.ndarray`): initial timestep sequence.

        Returns:
            `DDIMSchedulerState`: the freshly built state.
        """
        # Arguments are passed positionally, matching the field declaration order above.
        return cls(common, final_alpha_cumprod, init_noise_sigma, timesteps)


@dataclass
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
    """
    Output of `FlaxDDIMScheduler.step`.

    Extends the base scheduler output (which carries `prev_sample`) with the scheduler state to feed into the next
    call.
    """

    # Scheduler state returned alongside the denoised sample.
    state: DDIMSchedulerState


66
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`):
            what the model predicts: the noise (`epsilon`), the denoised sample (`sample`), or the v-prediction
            target (`v_prediction`).
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        # Mutable data lives in an explicit `DDIMSchedulerState` (see `create_state`), not on the scheduler itself.
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        # The arguments are stored on `self.config` by `@register_to_config`; `dtype` is additionally kept as an
        # attribute because state construction reads it directly.
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
        """
        Builds the initial scheduler state.

        Args:
            common (`CommonSchedulerState`, optional):
                shared schedule quantities; built with `CommonSchedulerState.create(self)` when `None`.

        Returns:
            `DDIMSchedulerState`: a state whose `timesteps` cover every training step in descending order and whose
            `num_inference_steps` is still unset (call `set_timesteps` before `step`).
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        final_alpha_cumprod = (
            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return DDIMSchedulerState.create(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        DDIM does not rescale the model input; the sample is returned unchanged.

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDIMSchedulerState`):
                the `FlaxDDIMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, optional):
                not used by this scheduler; accepted for interface compatibility.

        Returns:
            `DDIMSchedulerState`: a copy of `state` with `num_inference_steps` and descending `timesteps` set.
        """
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
        )

    def _get_alpha_prod_t_prev(self, state: DDIMSchedulerState, prev_timestep):
        """Returns the alpha product at `prev_timestep`, or `final_alpha_cumprod` when `prev_timestep` is negative."""
        # A negative index still gathers (wrapping from the end), but `jnp.where` discards that value; using `where`
        # instead of a Python `if` keeps this usable when `prev_timestep` is a traced array.
        return jnp.where(prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod)

    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
        """
        Computes the DDIM variance sigma_t^2 (before scaling by eta) between `timestep` and `prev_timestep`.

        sigma_t^2 = (1 - a_{t-1}) / (1 - a_t) * (1 - a_t / a_{t-1}), with a = alphas_cumprod.
        """
        alpha_prod_t = state.common.alphas_cumprod[timestep]
        alpha_prod_t_prev = self._get_alpha_prod_t_prev(state, prev_timestep)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        eta: float = 0.0,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            eta (`float`, default `0.0`):
                weight of the DDIM noise scale sigma_t(eta) from formula (16). This implementation never samples the
                random-noise term, so a non-zero `eta` only shrinks the "direction pointing to x_t" component.
            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

        Returns:
            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            `ValueError`: if `set_timesteps` has not been called, or if `prediction_type` is not one of `epsilon`,
            `sample`, `v_prediction`.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_epsilon -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = state.common.alphas_cumprod[timestep]
        alpha_prod_t_prev = self._get_alpha_prod_t_prev(state, prev_timestep)
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute the predicted original sample ("predicted x_0" of formula (12)) AND the predicted noise.
        # The direction term in step 5 needs the noise estimate; for `sample` and `v_prediction` the raw model output
        # is not the noise, so it has to be derived from x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps.
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(state, timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, state)

        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: DDIMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Noises `original_samples` at `timesteps`; delegates to `add_noise_common` with the shared state."""
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def get_velocity(
        self,
        state: DDIMSchedulerState,
        sample: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Computes the velocity target; delegates to `get_velocity_common` with the shared state."""
        return get_velocity_common(state.common, sample, noise, timesteps)

    def __len__(self):
        # Length of the scheduler is the number of training timesteps.
        return self.config.num_train_timesteps