"sgl-kernel/pyproject_cpu.toml" did not exist on "d353d08b4e8987f6e4a9c6e36c266c4dc00e7942"
scheduling_lms_discrete.py 12.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import warnings
15
from dataclasses import dataclass
16
from typing import Optional, Tuple, Union
17
18
19
20
21
22
23

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
24
from ..utils import BaseOutput, deprecate
25
26
27
28
from .scheduling_utils import SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
    Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            The values are converted to a `torch.float32` tensor.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        if trained_betas is not None:
            # `torch.tensor` (unlike `torch.from_numpy`) also accepts plain sequences, and forcing float32 keeps the
            # dtype consistent with the built-in schedules below.
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod), reversed so it decreases, with a trailing 0.0
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.derivatives = []
        self.is_scale_input_called = False

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain. It must match
                exactly one entry of `self.timesteps`; otherwise the `.item()` lookup raises.

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        The coefficient is the integral over `[sigmas[t], sigmas[t + 1]]` of the Lagrange basis polynomial built on
        the nodes `sigmas[t - k]` for `k in range(order)`. That polynomial is 1 at `sigmas[t - current_order]` and 0
        at the other nodes.

        Args:
            order (`int`): number of interpolation nodes (past steps) used by the multistep method.
            t (`int`): index of the current step into `self.sigmas`.
            current_order (`int`): which node the coefficient belongs to (0 = current step, 1 = previous step, ...).

        Returns:
            `float`: the integrated coefficient.
        """
        # Read the nodes once as Python floats. `quad` evaluates the integrand many times, and re-indexing a
        # (possibly accelerator-resident) tensor on every evaluation is needless overhead.
        nodes = [float(self.sigmas[t - k]) for k in range(order)]
        anchor = nodes[current_order]

        def lms_derivative(tau):
            prod = 1.0
            for k, node in enumerate(nodes):
                if k == current_order:
                    continue
                prod *= (tau - node) / (anchor - node)
            return prod

        integrated_coeff = integrate.quad(
            lms_derivative, float(self.sigmas[t]), float(self.sigmas[t + 1]), epsrel=1e-4
        )[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        # interpolate the training sigmas at the (possibly fractional) inference timesteps
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

        self.derivatives = []

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain. Integer values
                (a Python `int` or an integer tensor on any device) are treated as deprecated indices into
                `self.timesteps`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order (`int`): number of past derivatives combined by the linear multistep update.
            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.
        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        # `torch.is_floating_point` also recognises integer tensors on accelerators, which
        # `isinstance(x, torch.LongTensor)` (CPU-only tensor type) would miss.
        is_index = isinstance(timestep, int) or (
            isinstance(timestep, torch.Tensor) and not torch.is_floating_point(timestep)
        )
        if is_index:
            deprecate(
                "timestep as an index",
                "0.8.0",
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep.",
                standard_warn=False,
            )
            step_index = int(timestep)
        else:
            step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        # keep only the `order` most recent derivatives (a loop also copes with `order` shrinking between calls)
        while len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients; the first steps have fewer past derivatives available
        order = min(step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path (newest derivative pairs with coefficient 0)
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        if not return_dict:
            return (prev_sample,)

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add noise to `original_samples` at the noise levels of the given `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples; the first dimension indexes `timesteps`.
            noise (`torch.FloatTensor`): noise to add, broadcast against `original_samples`.
            timesteps (`torch.FloatTensor`): values taken from `self.timesteps`. Integer tensors are treated as
                deprecated indices into `self.sigmas`.

        Returns:
            `torch.FloatTensor`: `original_samples + noise * sigma`, with `sigma` broadcast over trailing dims.
        """
        # Cast copies of sigmas/timesteps to the samples' device and dtype. Reassigning `self.sigmas` here would
        # permanently change scheduler state (e.g. downcast to float16) and degrade subsequent `step` calls.
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # integer tensors on any device are legacy step indices
        if not torch.is_floating_point(timesteps):
            deprecate(
                "timesteps as indices",
                "0.8.0",
                "Passing integer indices  (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
                " pass values from `scheduler.timesteps` as timesteps.",
                standard_warn=False,
            )
            step_indices = timesteps
        else:
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # append singleton dims so sigma broadcasts per-sample over the remaining dimensions
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps