# scheduling_lms_discrete_flax.py
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
    CommonSchedulerState,
    FlaxKarrasDiffusionSchedulers,
    FlaxSchedulerMixin,
    FlaxSchedulerOutput,
    broadcast_to_shape_from_left,
)


@flax.struct.dataclass
class LMSDiscreteSchedulerState:
    """
    State container for `FlaxLMSDiscreteScheduler`.

    As a `flax.struct.dataclass` it is frozen; the scheduler produces updated copies with `state.replace(...)`.
    """

    # shared scheduler values (e.g. `alphas_cumprod`)
    common: CommonSchedulerState

    # setable values
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    sigmas: jnp.ndarray
    # stays None until `set_timesteps` has been called
    num_inference_steps: Optional[int] = None

    # running values
    # history of ODE derivatives; `set_timesteps` initialises it with shape (0, *shape)
    derivatives: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
    ) -> "LMSDiscreteSchedulerState":
        """Build a fresh state; `num_inference_steps` and `derivatives` keep their `None` defaults."""
        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)


@dataclass
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
    """
    Output of `FlaxLMSDiscreteScheduler.step`.

    Extends `FlaxSchedulerOutput` (whose `prev_sample` field is filled by `step`) with the updated scheduler state,
    which callers must thread into the next `step` call.
    """

    state: LMSDiscreteSchedulerState


class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
    Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the denoised sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        """This scheduler keeps its mutable values in an explicit `LMSDiscreteSchedulerState`."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        # The remaining arguments are stored in `self.config` by `register_to_config`; only the dtype is kept here.
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
        """
        Create the initial scheduler state.

        Args:
            common (`CommonSchedulerState`, *optional*):
                shared scheduler state; built with `CommonSchedulerState.create(self)` when not given.

        Returns:
            `LMSDiscreteSchedulerState`: a state with one entry per training timestep. Call `set_timesteps` before
            calling `step`.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # NOTE(review): `timesteps` is descending (N-1 ... 0) while `sigmas[i]` belongs to training timestep `i`, so
        # the two arrays are not index-aligned in this initial state (`add_noise` relies on the `sigmas[t]` layout).
        # `set_timesteps` rebuilds both with matching indices.
        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
        sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        init_noise_sigma = sigmas.max()

        return LMSDiscreteSchedulerState.create(
            common=common,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
            sigmas=sigmas,
        )

    def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

        Args:
            state (`LMSDiscreteSchedulerState`):
                the `FlaxLMSDiscreteScheduler` state data class instance.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            timestep (`int`):
                current discrete timestep in the diffusion chain.

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        # Map the timestep value to its position in `state.timesteps`. With `size=1`, `jnp.where` pads with index 0
        # when the timestep is not present, so an unknown timestep silently uses the first sigma.
        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
        step_index = step_index[0]

        sigma = state.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        The coefficient is the integral, over `[sigmas[t], sigmas[t + 1]]`, of the Lagrange basis polynomial that
        equals 1 at `sigmas[t - current_order]` and 0 at every other node `sigmas[t - k]` for `k` in `range(order)`.

        Args:
            state (`LMSDiscreteSchedulerState`):
                the `FlaxLMSDiscreteScheduler` state data class instance; only `state.sigmas` is read.
            order (`int`):
                number of interpolation nodes (past derivatives) used by the multistep formula. Callers must ensure
                `t - order + 1 >= 0`, otherwise negative indices wrap around to the end of `sigmas`.
            t (`int`):
                index into `state.sigmas` of the current step; `sigmas[t + 1]` is the integration end point.
            current_order (`int`):
                which node the basis polynomial belongs to; 0 is the current step, 1 the previous one, and so on.

        Returns:
            `float`: the integrated coefficient computed with `scipy.integrate.quad`. This runs on the host, so it
            needs concrete (non-traced) sigma values.
        """

        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
            return prod

        # `quad` returns (value, abs_error); only the value is needed.
        integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(
        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> LMSDiscreteSchedulerState:
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`LMSDiscreteSchedulerState`):
                the `FlaxLMSDiscreteScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, defaults to `()`):
                shape of a single sample; the derivative history is reset to an empty array of shape `(0, *shape)`.

        Returns:
            `LMSDiscreteSchedulerState`: a copy of `state` with new `timesteps`, `sigmas`, `num_inference_steps` and
            an empty `derivatives` history.
        """
        timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)

        # Each (possibly fractional) timestep lies between two training timesteps; interpolate their sigmas linearly.
        low_idx = jnp.floor(timesteps).astype(jnp.int32)
        high_idx = jnp.ceil(timesteps).astype(jnp.int32)

        frac = jnp.mod(timesteps, 1.0)

        sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        # Trailing 0 is the final sigma, so `sigmas[i + 1]` exists for every step index `i`.
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

        timesteps = timesteps.astype(jnp.int32)

        # initial running values
        derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)

        return state.replace(
            timesteps=timesteps,
            sigmas=sigmas,
            num_inference_steps=num_inference_steps,
            derivatives=derivatives,
        )

    def step(
        self,
        state: LMSDiscreteSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`):
                current step of the diffusion chain; it is used directly as an index into `state.sigmas`.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            order (`int`, defaults to 4): maximum number of past derivatives combined by the multistep formula.
            return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

        Returns:
            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor and the second the updated state.

        Raises:
            ValueError: if `set_timesteps` has not been called, or if `prediction_type` is not supported.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # NOTE(review): `timestep` indexes `state.sigmas` directly (and bounds `order` below), i.e. it is treated as
        # the step index, whereas `scale_model_input` maps a timestep *value* through `state.timesteps`. Confirm that
        # callers pass the step index here.
        sigma = state.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            # the model output already is the denoised sample
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 2. Convert to an ODE derivative and record it in the history.
        derivative = (sample - pred_original_sample) / sigma
        # Stack along a new leading axis so each history entry is one whole derivative tensor. A bare `jnp.append`
        # flattens both operands to 1-D, which would mix the elements of different steps together.
        new_entry = jnp.expand_dims(derivative, 0)
        history = state.derivatives
        if history is None or history.shape[0] == 0:
            # Start fresh; this also works when `set_timesteps` was called without the sample `shape`.
            history = new_entry
        else:
            history = jnp.concatenate([history, new_entry], axis=0)
        # Keep only the `order` most recent derivatives (oldest first along axis 0).
        if history.shape[0] > order:
            history = history[history.shape[0] - order :]
        state = state.replace(derivatives=history)

        # 3. Compute linear multistep coefficients
        # Early steps have fewer past sigmas available, so the effective order is capped by the step index.
        order = min(timestep + 1, order)
        lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        # Coefficient 0 belongs to the newest derivative, hence the reversed history.
        prev_sample = sample + sum(
            coeff * past_derivative for coeff, past_derivative in zip(lms_coeffs, reversed(state.derivatives))
        )

        if not return_dict:
            return (prev_sample, state)

        return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: LMSDiscreteSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Add noise to `original_samples`, scaled by the sigma looked up for each entry of `timesteps`.

        Args:
            state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
            original_samples (`jnp.ndarray`): clean samples.
            noise (`jnp.ndarray`): noise to add; its shape drives the broadcast of `sigma`.
            timesteps (`jnp.ndarray`): indices into `state.sigmas`.

        Returns:
            `jnp.ndarray`: `original_samples + noise * sigma`.
        """
        # NOTE(review): `timesteps` indexes `state.sigmas` directly. That matches the per-training-timestep layout
        # built by `create_state`; after `set_timesteps`, `sigmas` has one entry per inference step instead.
        sigma = state.sigmas[timesteps].flatten()
        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

        noisy_samples = original_samples + noise * sigma

        return noisy_samples

    def __len__(self):
        """Number of training timesteps, as configured."""
        return self.config.num_train_timesteps