# scheduling_sde_ve_flax.py
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
25
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40


@flax.struct.dataclass
class ScoreSdeVeSchedulerState:
    """State container for the score SDE-VE scheduler; every field defaults to `None`."""

    # settable values
    timesteps: Optional[jnp.ndarray] = None
    discrete_sigmas: Optional[jnp.ndarray] = None
    sigmas: Optional[jnp.ndarray] = None

    @classmethod
    def create(cls):
        """Return a new state in which no field has been populated yet."""
        return cls(timesteps=None, discrete_sigmas=None, sigmas=None)


@dataclass
class FlaxSdeVeOutput(FlaxSchedulerOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        state (`ScoreSdeVeSchedulerState`):
            The scheduler state data class instance handed back together with the sample.
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. Defaults
            to `None` when the producing step does not provide it.
    """

    state: ScoreSdeVeSchedulerState
    prev_sample: jnp.ndarray
    prev_sample_mean: Optional[jnp.ndarray] = None


59
class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise.
        sigma_min (`float`):
            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
            distribution of the data.
        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`):
            the end value of sampling, where timesteps decrease progressively from 1 to epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample.
    """

    @property
    def has_state(self):
        """Always `True`: this scheduler keeps its arrays in a separate `ScoreSdeVeSchedulerState`."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # The body is intentionally empty; the arguments are only read back through `self.config` elsewhere
        # in this class.
        pass

    def create_state(self):
        """
        Create a fresh `ScoreSdeVeSchedulerState` and fill in its sigmas (and timesteps) from the config values,
        using `num_train_timesteps` as the number of steps.
        """
        state = ScoreSdeVeSchedulerState.create()
        return self.set_sigmas(
            state,
            self.config.num_train_timesteps,
            self.config.sigma_min,
            self.config.sigma_max,
            self.config.sampling_eps,
        )

    def set_timesteps(
        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, optional): not read by this method.
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `timesteps` replaced.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
        return state.replace(timesteps=timesteps)

    def set_sigmas(
        self,
        state: ScoreSdeVeSchedulerState,
        num_inference_steps: int,
        sigma_min: float = None,
        sigma_max: float = None,
        sampling_eps: float = None,
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                initial noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `discrete_sigmas` and `sigmas` replaced (and
            `timesteps` as well if they were not set on entry).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if state.timesteps is None:
            # `sampling_eps` must be passed by keyword: the third positional parameter of `set_timesteps`
            # is `shape`, so a positional argument would be silently discarded.
            state = self.set_timesteps(state, num_inference_steps, sampling_eps=sampling_eps)

        discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
        # Geometric interpolation between sigma_min and sigma_max, evaluated for all timesteps at once.
        sigmas = sigma_min * (sigma_max / sigma_min) ** state.timesteps

        return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)

    def get_adjacent_sigma(self, state, timesteps, t):
        """
        Look up `state.discrete_sigmas[timesteps - 1]`, substituting zero wherever `timesteps == 0`.

        Args:
            state (`ScoreSdeVeSchedulerState`): state holding `discrete_sigmas`.
            timesteps (`jnp.ndarray`): integer indices into `state.discrete_sigmas`.
            t (`jnp.ndarray`): array whose shape and dtype are used for the zero fill values.
        """
        return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])

    def step_pred(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jnp.ndarray`): JAX PRNG key from which the noise for this step is drawn.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # one timestep value per batch element
        timestep = timestep * jnp.ones(
            sample.shape[0],
        )
        # JAX arrays have no torch-style `.long()`; cast to an integer dtype so the result can index
        # `discrete_sigmas`.
        timesteps = (timestep * (len(state.timesteps) - 1)).astype(jnp.int32)

        sigma = state.discrete_sigmas[timesteps]
        adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
        drift = jnp.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term
        # `random.split(..., num=1)` returns a stacked array of keys; take the single key out of it
        # before handing it to `random.normal`.
        noise_key = random.split(key, num=1)[0]
        noise = random.normal(key=noise_key, shape=sample.shape)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)

    def step_correct(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jnp.ndarray`): JAX PRNG key from which the correction noise is drawn.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        # `random.split(..., num=1)` returns a stacked array of keys; take the single key out of it
        # before handing it to `random.normal`.
        noise_key = random.split(key, num=1)[0]
        noise = random.normal(key=noise_key, shape=sample.shape)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = jnp.linalg.norm(model_output)
        noise_norm = jnp.linalg.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * jnp.ones(sample.shape[0])

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)

    def __len__(self):
        """Return the configured `num_train_timesteps`."""
        return self.config.num_train_timesteps