scheduling_lms_discrete.py 11.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import warnings
15
from dataclasses import dataclass
16
from typing import Optional, Tuple, Union
17
18
19
20
21
22
23

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
24
from ..utils import BaseOutput, deprecate
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from .scheduling_utils import SchedulerMixin


@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance. Optional; defaults to `None`.
    """

    # Required: the sample to feed into the next denoising iteration.
    prev_sample: torch.FloatTensor
    # Optional: the x_0 estimate; `None` when the producer does not supply one.
    pred_original_sample: Optional[torch.FloatTensor] = None
44
45
46


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
47
48
49
50
51
    """
    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
    Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.

    """

69
70
71
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        """
        Build the training noise schedule and initialise the mutable sampling state.

        Args:
            num_train_timesteps (`int`): number of entries in the generated beta schedule.
            beta_start (`float`): first beta of the generated schedule.
            beta_end (`float`): last beta of the generated schedule.
            beta_schedule (`str`): `"linear"` or `"scaled_linear"`; not consulted when `trained_betas` is given.
            trained_betas (`np.ndarray`, optional): explicit betas used instead of a generated schedule.

        Raises:
            NotImplementedError: if no `trained_betas` are passed and `beta_schedule` is not a supported name.
        """
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            # It is linear in sqrt(beta); the result is squared to get the betas.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod), stored in reversed index order with a trailing 0.0
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        # setable values
        self.num_inference_steps = None
        # Default schedule: every training timestep, in descending order.
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # History of ODE derivatives consumed by `step`.
        self.derivatives = []
        # Lets `step` warn when `scale_model_input` was never called.
        self.is_scale_input_called = False

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Divide the model input by `(sigma**2 + 1) ** 0.5`, where `sigma` is the entry of `self.sigmas` at the
        position of `timestep` in `self.timesteps`, to match the K-LMS algorithm.

        Also records on the scheduler that the input has been scaled (`self.is_scale_input_called`).

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain; it is looked up
                by equality in `self.timesteps`.

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        # A tensor timestep has to sit on the schedule's device for the equality lookup below.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        # Position of `timestep` inside the schedule; `.item()` expects exactly one match.
        matches = (self.timesteps == timestep).nonzero()
        current_sigma = self.sigmas[matches.item()]

        scaled_sample = sample / ((current_sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return scaled_sample
127
128
129

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        The integrand is the product, over every `k` in `range(order)` except `current_order`, of
        `(tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])`, i.e. the Lagrange basis polynomial
        attached to the node `sigmas[t - current_order]`. It is integrated numerically from `sigmas[t]` to
        `sigmas[t + 1]`.

        Args:
            order (`int`): number of sigma nodes (`sigmas[t]`, `sigmas[t - 1]`, ...) the polynomial is built on.
            t (`int`): index into `self.sigmas` of the current step.
            current_order (`int`): which of the `order` nodes this coefficient belongs to, counted backwards from
                `t` (0 is the node at `t` itself).

        Returns:
            The integral estimate returned by `scipy.integrate.quad` (called with `epsrel=1e-4`).
        """
        # Offsets of all nodes other than the one this coefficient belongs to.
        other_offsets = [k for k in range(order) if k != current_order]

        def basis_polynomial(tau):
            value = 1.0
            for k in other_offsets:
                value *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return value

        lower_limit = self.sigmas[t]
        upper_limit = self.sigmas[t + 1]
        quad_result = integrate.quad(basis_polynomial, lower_limit, upper_limit, epsrel=1e-4)

        # `quad` returns (estimate, error bound); only the estimate is used.
        return quad_result[0]

150
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Rebuilds `self.timesteps` as `num_inference_steps` evenly spaced values from
        `config.num_train_timesteps - 1` down to 0, interpolates the training sigmas at those (possibly fractional)
        positions, appends a final `0.0`, and clears the stored derivative history.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # Descending, evenly spaced positions on the training timestep axis.
        last_train_step = self.config.num_train_timesteps - 1
        schedule = np.linspace(0, last_train_step, num_inference_steps, dtype=float)[::-1].copy()

        # Sigma for every training timestep, then sampled at the schedule positions.
        train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        train_positions = np.arange(0, len(train_sigmas))
        schedule_sigmas = np.interp(schedule, train_positions, train_sigmas)
        padded_sigmas = np.concatenate([schedule_sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(padded_sigmas).to(device=device)
        self.timesteps = torch.from_numpy(schedule).to(device=device)

        # A new schedule invalidates derivatives collected under the previous one.
        self.derivatives = []

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one step of the linear multistep update, using the current model output together
        with the derivatives remembered from earlier calls.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain. It is looked up by equality in
                `self.timesteps`; passing an integer index instead is deprecated.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order: maximum number of stored derivatives combined in the update. The effective order is additionally
                capped at `step_index + 1`.
            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

        Returns:
            [`LMSDiscreteSchedulerOutput`] or `tuple`:
            [`LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        # A tensor timestep has to sit on the schedule's device for the equality lookup below.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            # Legacy calling convention: the caller passed a position, not a schedule value.
            deprecate(
                "timestep as an index",
                "0.7.0",
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep.",
                standard_warn=False,
            )
            step_index = timestep
        else:
            step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative and keep at most `order` of them, oldest first
        ode_derivative = (sample - pred_original_sample) / sigma
        history = self.derivatives
        history.append(ode_derivative)
        if len(history) > order:
            history.pop(0)

        # 3. Compute linear multistep coefficients (fewer steps are available early in the schedule)
        effective_order = min(step_index + 1, order)
        lms_coeffs = [
            self.get_lms_coefficient(effective_order, step_index, node) for node in range(effective_order)
        ]

        # 4. Compute previous sample based on the derivatives path (newest derivative pairs with coefficient 0)
        prev_sample = sample + sum(
            weight * past_derivative for weight, past_derivative in zip(lms_coeffs, reversed(history))
        )

        if return_dict:
            return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

        return (prev_sample,)
245

246
247
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Return `original_samples + noise * sigma`, where `sigma` is taken from `self.sigmas` for each entry of
        `timesteps` and broadcast over the trailing dimensions of `original_samples`.

        Side effect: `self.sigmas` is moved to the device and dtype of `original_samples`, and `self.timesteps` is
        moved to its device.

        Args:
            original_samples (`torch.FloatTensor`): samples to which noise is added.
            noise (`torch.FloatTensor`): noise to add.
            timesteps (`torch.FloatTensor`): one timestep per sample, each looked up by equality in
                `self.timesteps`. Passing integer index tensors instead is deprecated.

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        target_device = original_samples.device
        self.sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)
        self.timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

        schedule_timesteps = self.timesteps

        if isinstance(timesteps, (torch.IntTensor, torch.LongTensor)):
            # Legacy calling convention: the caller passed positions, not schedule values.
            deprecate(
                "timesteps as indices",
                "0.7.0",
                "Passing integer indices  (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
                " pass values from `scheduler.timesteps` as timesteps.",
                standard_warn=False,
            )
            step_indices = timesteps
        else:
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = self.sigmas[step_indices].flatten()
        # Append singleton dimensions until sigma has as many dimensions as the samples.
        for _ in range(original_samples.ndim - sigma.ndim):
            sigma = sigma.unsqueeze(-1)

        return original_samples + noise * sigma

    def __len__(self):
        return self.config.num_train_timesteps