"llama/llama.cpp/vscode:/vscode.git/clone" did not exist on "854d40edc5c5894014a9aea28fcca7b7aeba83bb"
scheduling_lms_discrete_flax.py 8.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
23
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39


@flax.struct.dataclass
class LMSDiscreteSchedulerState:
    """
    Immutable container for the values the LMS scheduler carries between calls.

    Instances are never modified in place; updated copies are produced with `state.replace(...)`.
    """

    # setable values
    # number of inference steps chosen for the current sampling run (None until one is chosen)
    num_inference_steps: Optional[int] = None
    # timestep indices, in descending order when built by `create`
    timesteps: Optional[jnp.ndarray] = None
    # noise levels looked up by the scheduler
    sigmas: Optional[jnp.ndarray] = None
    # History of ODE derivatives; empty by default. A factory is used instead of a literal array default because
    # JAX arrays are unhashable and Python 3.11+ dataclasses refuse unhashable defaults ("mutable default ... use
    # default_factory"). It also avoids allocating an array as a side effect of importing the module.
    derivatives: jnp.ndarray = flax.struct.field(default_factory=lambda: jnp.array([]))

    @classmethod
    def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
        """
        Build a state whose timesteps run from `num_train_timesteps - 1` down to `0`.

        Args:
            num_train_timesteps (`int`): number of timesteps to enumerate.
            sigmas (`jnp.ndarray`): sigma values stored on the state unchanged.

        Returns:
            `LMSDiscreteSchedulerState`: new state; `num_inference_steps` and `derivatives` keep their defaults.
        """
        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas)


@dataclass
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
    """
    Value returned by `FlaxLMSDiscreteScheduler.step` when `return_dict` is True.

    It extends `FlaxSchedulerOutput` with the scheduler state produced by the step; `step` constructs it with the
    keyword arguments `prev_sample` and `state`.
    """

    # scheduler state after the step has been applied
    state: LMSDiscreteSchedulerState


44
class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
    Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
    """

    @property
    def has_state(self):
        """Always `True`: running values live in a separate `LMSDiscreteSchedulerState` object."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
    ):
        # Explicit betas take precedence over any named schedule.
        if trained_betas is not None:
            self.betas = jnp.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model:
            # interpolate linearly in sqrt(beta) space, then square.
            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

    def _training_sigmas(self) -> jnp.ndarray:
        """Return `sqrt((1 - alphas_cumprod) / alphas_cumprod)`, one sigma per entry of `self.alphas_cumprod`."""
        return ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

    def create_state(self) -> LMSDiscreteSchedulerState:
        """
        Build the initial scheduler state for the full training schedule.

        The state is stored on `self.state` and also returned, so callers can use either access path.

        Returns:
            `LMSDiscreteSchedulerState`: state with timesteps `num_train_timesteps - 1, ..., 0` and the sigmas
            derived from `self.alphas_cumprod`.
        """
        self.state = LMSDiscreteSchedulerState.create(
            num_train_timesteps=self.config.num_train_timesteps,
            sigmas=self._training_sigmas(),
        )
        return self.state

    def get_lms_coefficient(self, state, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        The coefficient is the integral, from `state.sigmas[t]` to `state.sigmas[t + 1]`, of the polynomial that
        equals 1 at `state.sigmas[t - current_order]` and 0 at every other node `state.sigmas[t - k]` for `k` in
        `range(order)`.

        Args:
            state (`LMSDiscreteSchedulerState`): scheduler state; only `state.sigmas` is read.
            order (`int`): number of sigma nodes (`t`, `t - 1`, ..., `t - order + 1`) the polynomial is built on.
            t (`int`):
                index of the current step in `state.sigmas`. `t - k` is used as a plain index, so `order` should not
                exceed `t + 1` (the `step` method clamps it accordingly).
            current_order (`int`): which node, counted backwards from `t`, this coefficient belongs to.

        Returns:
            `float`: the integral value reported by `scipy.integrate.quad`.
        """
        # The node this coefficient belongs to does not change inside the loop below.
        node_sigma = state.sigmas[t - current_order]

        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - state.sigmas[t - k]) / (node_sigma - state.sigmas[t - k])
            return prod

        # quad returns (value, abs_error_estimate); only the value is needed.
        integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(
        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> LMSDiscreteSchedulerState:
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`LMSDiscreteSchedulerState`):
                the `FlaxLMSDiscreteScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`):
                shape of a single derivative entry; the derivative history is reset to an empty array of shape
                `(0, *shape)`. With the default `()` this is an empty 1-D array.

        Returns:
            `LMSDiscreteSchedulerState`: a new state holding the inference timesteps, `num_inference_steps + 1`
            sigmas (the last one is `0.0`) and an empty derivative history.
        """
        timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)

        # The inference timesteps are generally fractional: interpolate linearly between the two neighbouring
        # training sigmas.
        low_idx = jnp.floor(timesteps).astype(int)
        high_idx = jnp.ceil(timesteps).astype(int)
        frac = jnp.mod(timesteps, 1.0)
        sigmas = self._training_sigmas()
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        # Terminal sigma of 0.0 so that `sigmas[t + 1]` exists for the final step.
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32)

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps.astype(int),
            derivatives=jnp.zeros((0, *shape), dtype=jnp.float32),
            sigmas=sigmas,
        )

    def step(
        self,
        state: LMSDiscreteSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`):
                current step; it is used directly as an index into `state.sigmas`, and the sigma at that index must
                be non-zero because the derivative is divided by it.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            order (`int`):
                maximum number of past derivatives combined in the multistep update; it is reduced to
                `timestep + 1` when that is smaller.
            return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

        Returns:
            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple` `(prev_sample, state)`. When returning a tuple, the first element is the sample tensor.
        """
        sigma = state.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        # JAX arrays are immutable and have no list-style `append`/`pop`, so the history is kept as a single array
        # stacked along a new leading axis (oldest entry first) and rebuilt functionally.
        newest = jnp.expand_dims(derivative, axis=0)
        history = state.derivatives
        if history is None or history.size == 0:
            # Covers both the default empty `(0,)` buffer and a pre-shaped `(0, *shape)` buffer.
            history = newest
        else:
            history = jnp.concatenate([history, newest], axis=0)
        if history.shape[0] > order:
            # Drop the oldest derivative once more than `order` entries are stored.
            history = history[1:]
        state = state.replace(derivatives=history)

        # 3. Compute linear multistep coefficients
        order = min(timestep + 1, order)
        lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path (newest derivative pairs with coefficient 0)
        prev_sample = sample + sum(
            coeff * past_derivative for coeff, past_derivative in zip(lms_coeffs, history[::-1])
        )

        if not return_dict:
            return (prev_sample, state)

        return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: LMSDiscreteSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Add `noise`, scaled by the sigma of each given timestep, to `original_samples`.

        Args:
            state (`LMSDiscreteSchedulerState`): scheduler state; only `state.sigmas` is read.
            original_samples (`jnp.ndarray`): samples to be noised.
            noise (`jnp.ndarray`): noise to add; its shape is the target shape for the broadcast sigmas.
            timesteps (`jnp.ndarray`): indices into `state.sigmas` selecting the noise level(s).

        Returns:
            `jnp.ndarray`: `original_samples + noise * sigma`.
        """
        sigma = state.sigmas[timesteps].flatten()
        # presumably aligns the flattened sigmas with the leading axes of `noise` -- see
        # `broadcast_to_shape_from_left` in `scheduling_utils_flax`
        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

        noisy_samples = original_samples + noise * sigma

        return noisy_samples

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps