# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
    CommonSchedulerState,
    FlaxKarrasDiffusionSchedulers,
    FlaxSchedulerMixin,
    FlaxSchedulerOutput,
    add_noise_common,
)


@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
    """
    State container passed to and returned from the `FlaxDPMSolverMultistepScheduler` methods.

    `create_state` fills the schedule tables and the default timesteps, `set_timesteps` fills the inference
    schedule and resets the running values, and `step` updates the running values via `.replace(...)`.
    """

    common: CommonSchedulerState
    # schedule tables derived from `common.alphas_cumprod` (see `create_state`)
    alpha_t: jnp.ndarray
    sigma_t: jnp.ndarray
    lambda_t: jnp.ndarray

    # setable values
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None

    # running values
    # buffer of converted model outputs; `step` rolls it and writes the newest entry last
    model_outputs: Optional[jnp.ndarray] = None
    # counter incremented by `step`, capped at `solver_order`
    lower_order_nums: Optional[jnp.int32] = None
    prev_timestep: Optional[jnp.int32] = None
    cur_sample: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        alpha_t: jnp.ndarray,
        sigma_t: jnp.ndarray,
        lambda_t: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """Build a state from the required fields only; every optional field keeps its `None` default."""
        return cls(
            common=common,
            alpha_t=alpha_t,
            sigma_t=sigma_t,
            lambda_t=lambda_t,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )


@dataclass
class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
    """Value returned by `FlaxDPMSolverMultistepScheduler.step` when `return_dict=True`."""

    # scheduler state after the step
    state: DPMSolverMultistepSchedulerState


class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
    the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
    samples, and it can generate quite good samples even in only 10 steps.

    For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
    https://huggingface.co/papers/2211.01095

    Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
    recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.

    We also support the "dynamic thresholding" method in Imagen (https://huggingface.co/papers/2205.11487). For
    pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the
    dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
    stable-diffusion).

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        solver_order (`int`, default `2`):
            the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
            or `v_prediction`.
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen,
            https://huggingface.co/papers/2205.11487). For pixel-space diffusion models, you can set both
            `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic thresholding. Note that the
            thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://huggingface.co/papers/2205.11487).
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, default `dpmsolver++`):
            the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
            algorithms in https://huggingface.co/papers/2206.00927, and the `dpmsolver++` type implements the
            algorithms in https://huggingface.co/papers/2211.01095. We recommend to use `dpmsolver++` with
            `solver_order=2` for guided sampling (e.g. stable-diffusion).
        solver_type (`str`, default `midpoint`):
            the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
            the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
            slightly better, so we recommend to use the `midpoint` type.
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
            find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled; one of `"linspace"`, `"leading"` or `"trailing"`. Refer to Table 2
            of the [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891) for more information.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        """Always `True`: this scheduler works on an explicit state object (see `create_state`)."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        timestep_spacing: str = "linspace",
        dtype: jnp.dtype = jnp.float32,
        steps_offset: int = 0,
    ):
        """
        Store the computation dtype on the instance.

        `steps_offset` (`int`, default `0`) is the offset that `set_timesteps` adds to the inference timesteps when
        `timestep_spacing="leading"`. It is appended as the last keyword so existing positional callers are
        unaffected.
        """
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
        """
        Build the initial scheduler state.

        Args:
            common (`CommonSchedulerState`, optional):
                shared schedule state; created from this scheduler via `CommonSchedulerState.create(self)` when
                omitted.

        Returns:
            `DPMSolverMultistepSchedulerState`: state holding the `alpha_t` / `sigma_t` / `lambda_t` tables, the
            initial noise sigma and the full (descending) training timesteps.

        Raises:
            NotImplementedError: if `algorithm_type` or `solver_type` in the config is not supported.
        """
        # settings for DPM-Solver (validated first so an invalid config fails before any array work)
        if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
            raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
        if self.config.solver_type not in ["midpoint", "heun"]:
            raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")

        if common is None:
            common = CommonSchedulerState.create(self)

        # Currently we only support VP-type noise schedule
        alpha_t = jnp.sqrt(common.alphas_cumprod)
        sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
        lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        # default schedule: every training timestep, from last to first
        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return DPMSolverMultistepSchedulerState.create(
            common=common,
            alpha_t=alpha_t,
            sigma_t=sigma_t,
            lambda_t=lambda_t,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def set_timesteps(
        self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
    ) -> DPMSolverMultistepSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the `FlaxDPMSolverMultistepScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`):
                the shape of the samples to be generated.

        Returns:
            `DPMSolverMultistepSchedulerState`: a copy of `state` with the inference timesteps set and the running
            values (`model_outputs`, `lower_order_nums`, `prev_timestep`, `cur_sample`) reset.

        Raises:
            ValueError: if `timestep_spacing` in the config is not `linspace`, `leading` or `trailing`.
        """
        last_timestep = self.config.num_train_timesteps
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (
                (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32)
            )
            # `steps_offset` may not be registered in the config; treat a missing value as no offset
            # instead of failing with an AttributeError.
            timesteps += getattr(self.config, "steps_offset", 0)
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # initial running values
        model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype)
        lower_order_nums = jnp.int32(0)
        prev_timestep = jnp.int32(-1)
        cur_sample = jnp.zeros(shape, dtype=self.dtype)

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            model_outputs=model_outputs,
            lower_order_nums=lower_order_nums,
            prev_timestep=prev_timestep,
            cur_sample=cur_sample,
        )

    def convert_model_output(
        self,
        state: DPMSolverMultistepSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.

        DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
        discretize an integral of the data prediction model. So we need to first convert the model output to the
        corresponding type to match the algorithm.

        Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
        DPM-Solver++ for both noise prediction model and data prediction model.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the scheduler state; its `alpha_t` / `sigma_t` tables are indexed by `timestep`.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.

        Returns:
            `jnp.ndarray`: the converted model output.

        Raises:
            ValueError: if `prediction_type` in the config is not `epsilon`, `sample` or `v_prediction`.
            NotImplementedError: if `algorithm_type` in the config is not `dpmsolver` or `dpmsolver++`.
        """
        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type == "dpmsolver++":
            if self.config.prediction_type == "epsilon":
                alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                # Dynamic thresholding in https://huggingface.co/papers/2205.11487
                # `dynamic_thresholding_ratio` is a fraction in [0, 1], so it must go to `quantile`
                # (`percentile` expects 0-100). `keepdims=True` keeps one threshold per leading-axis entry
                # in a shape that broadcasts against `x0_pred`.
                dynamic_max_val = jnp.quantile(
                    jnp.abs(x0_pred),
                    self.config.dynamic_thresholding_ratio,
                    axis=tuple(range(1, x0_pred.ndim)),
                    keepdims=True,
                )
                dynamic_max_val = jnp.maximum(dynamic_max_val, self.config.sample_max_value)
                x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
            return x0_pred
        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type == "dpmsolver":
            if self.config.prediction_type == "epsilon":
                return model_output
            elif self.config.prediction_type == "sample":
                alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
                )

        # Fail loudly rather than silently returning None for an unknown algorithm.
        raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")

    def dpm_solver_first_order_update(
        self,
        state: DPMSolverMultistepSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        prev_timestep: int,
        sample: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        One step for the first-order DPM-Solver (equivalent to DDIM).

        See https://huggingface.co/papers/2206.00927 for the detailed derivation.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the scheduler state; its `lambda_t` / `alpha_t` / `sigma_t` tables are indexed by the timesteps.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.

        Returns:
            `jnp.ndarray`: the sample tensor at the previous timestep.
        """
        # t is the timestep being stepped to, s0 the timestep of `sample`
        t, s0 = prev_timestep, timestep
        m0 = model_output
        lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
        alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
        sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        state: DPMSolverMultistepSchedulerState,
        model_output_list: jnp.ndarray,
        timestep_list: List[int],
        prev_timestep: int,
        sample: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        One step for the second-order multistep DPM-Solver.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the scheduler state; its `lambda_t` / `alpha_t` / `sigma_t` tables are indexed by the timesteps.
            model_output_list (`jnp.ndarray`):
                converted model outputs; the last two entries along the first axis are used, the last one
                belonging to the current timestep.
            timestep_list (`List[int]`):
                discrete timesteps matching `model_output_list`; the last entry is the current timestep.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.

        Returns:
            `jnp.ndarray`: the sample tensor at the previous timestep.
        """
        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
        alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
        sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        # D0 is the newest output, D1 a finite difference of the two outputs scaled by 1 / r0
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://huggingface.co/papers/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://huggingface.co/papers/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (jnp.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (jnp.exp(h) - 1.0)) * D0
                    - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        state: DPMSolverMultistepSchedulerState,
        model_output_list: jnp.ndarray,
        timestep_list: List[int],
        prev_timestep: int,
        sample: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        One step for the third-order multistep DPM-Solver.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the scheduler state; its `lambda_t` / `alpha_t` / `sigma_t` tables are indexed by the timesteps.
            model_output_list (`jnp.ndarray`):
                converted model outputs; the last three entries along the first axis are used, the last one
                belonging to the current timestep.
            timestep_list (`List[int]`):
                discrete timesteps matching `model_output_list`; the last entry is the current timestep.
            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.

        Returns:
            `jnp.ndarray`: the sample tensor at the previous timestep.
        """
        t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
        lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
            state.lambda_t[t],
            state.lambda_t[s0],
            state.lambda_t[s1],
            state.lambda_t[s2],
        )
        alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
        sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        # first-order differences of consecutive outputs, then their combination into D1 and D2
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://huggingface.co/papers/2206.00927 for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
                + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://huggingface.co/papers/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (jnp.exp(h) - 1.0)) * D0
                - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        return x_t

    def step(
        self,
        state: DPMSolverMultistepSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
        from the learned model outputs (most often the predicted noise).

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the `FlaxDPMSolverMultistepScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class

        Returns:
            [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor
            and the second element is the updated state.

        Raises:
            ValueError: if `set_timesteps` has not been run on `state` (`num_inference_steps` is `None`).
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # position of `timestep` inside the inference schedule
        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
        step_index = step_index[0]

        # the last scheduled step targets timestep 0; every other step targets the next schedule entry
        prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])

        model_output = self.convert_model_output(state, model_output, timestep, sample)

        # shift the output buffer by one slot and store the newest converted output last
        model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
        model_outputs_new = model_outputs_new.at[-1].set(model_output)
        state = state.replace(
            model_outputs=model_outputs_new,
            prev_timestep=prev_timestep,
            cur_sample=sample,
        )

        def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            return self.dpm_solver_first_order_update(
                state,
                state.model_outputs[-1],
                state.timesteps[step_index],
                state.prev_timestep,
                state.cur_sample,
            )

        def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
                timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
                return self.multistep_dpm_solver_second_order_update(
                    state,
                    state.model_outputs,
                    timestep_list,
                    state.prev_timestep,
                    state.cur_sample,
                )

            def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
                timestep_list = jnp.array(
                    [
                        state.timesteps[step_index - 2],
                        state.timesteps[step_index - 1],
                        state.timesteps[step_index],
                    ]
                )
                return self.multistep_dpm_solver_third_order_update(
                    state,
                    state.model_outputs,
                    timestep_list,
                    state.prev_timestep,
                    state.cur_sample,
                )

            # Both candidates are computed and the result is picked with `jax.lax.select`
            # (presumably so the choice can depend on traced values such as `lower_order_nums`).
            step_2_output = step_2(state)
            step_3_output = step_3(state)

            if self.config.solver_order == 2:
                return step_2_output
            elif self.config.lower_order_final and len(state.timesteps) < 15:
                return jax.lax.select(
                    state.lower_order_nums < 2,
                    step_2_output,
                    jax.lax.select(
                        step_index == len(state.timesteps) - 2,
                        step_2_output,
                        step_3_output,
                    ),
                )
            else:
                return jax.lax.select(
                    state.lower_order_nums < 2,
                    step_2_output,
                    step_3_output,
                )

        step_1_output = step_1(state)
        step_23_output = step_23(state)

        if self.config.solver_order == 1:
            prev_sample = step_1_output

        elif self.config.lower_order_final and len(state.timesteps) < 15:
            # first-order on the very first call and again on the final scheduled step
            prev_sample = jax.lax.select(
                state.lower_order_nums < 1,
                step_1_output,
                jax.lax.select(
                    step_index == len(state.timesteps) - 1,
                    step_1_output,
                    step_23_output,
                ),
            )

        else:
            prev_sample = jax.lax.select(
                state.lower_order_nums < 1,
                step_1_output,
                step_23_output,
            )

        # one more converted output is now available, capped at the configured solver order
        state = state.replace(
            lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
        )

        if not return_dict:
            return (prev_sample, state)

        return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)

    def scale_model_input(
        self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Return `sample` unchanged.

        The method exists so this scheduler can be swapped with schedulers that do rescale the denoising model
        input depending on the current timestep.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the `FlaxDPMSolverMultistepScheduler` state data class instance (unused).
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep (unused)

        Returns:
            `jnp.ndarray`: the input sample, as passed in
        """
        return sample

    def add_noise(
        self,
        state: DPMSolverMultistepSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Add `noise` to `original_samples` for the given `timesteps`.

        Thin wrapper that forwards to `add_noise_common` using the shared schedule stored in `state.common`.

        Args:
            state (`DPMSolverMultistepSchedulerState`):
                the `FlaxDPMSolverMultistepScheduler` state data class instance.
            original_samples (`jnp.ndarray`): the samples to add noise to.
            noise (`jnp.ndarray`): the noise to add.
            timesteps (`jnp.ndarray`): the timesteps passed through to `add_noise_common`.

        Returns:
            `jnp.ndarray`: whatever `add_noise_common` returns for these inputs.
        """
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def __len__(self):
        return self.config.num_train_timesteps