# scheduling_lms_discrete.py
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Result container produced by the LMS scheduler's `step` method when a structured output is requested.

    Args:
        prev_sample (`torch.FloatTensor`, shape `(batch_size, num_channels, height, width)` for images):
            The sample for the previous timestep (x_{t-1}). Feed it back in as the model input on the next
            iteration of the denoising loop.
        pred_original_sample (`torch.FloatTensor`, shape `(batch_size, num_channels, height, width)` for images):
            The estimate of the fully denoised sample (x_{0}) derived from the current model output. Handy for
            previewing progress or for guidance. Optional; `None` when not provided.
    """

    # Required field: what the denoising loop consumes next.
    prev_sample: torch.FloatTensor
    # Optional field: the x_0 estimate for the current step.
    pred_original_sample: Optional[torch.FloatTensor] = None


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
    Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
            Any array-like accepted by `np.asarray` (e.g. a list of floats) also works.
        **kwargs:
            only inspected for the deprecated `tensor_format` argument, which triggers a `DeprecationWarning`.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        **kwargs,
    ):
        if "tensor_format" in kwargs:
            # The trailing space on the first literal keeps the two sentences from running together.
            warnings.warn(
                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. "
                "If you're running your code in PyTorch, you can safely remove this argument.",
                DeprecationWarning,
            )

        if trained_betas is not None:
            # np.asarray is a no-op for ndarrays and additionally accepts plain sequences such as lists.
            self.betas = torch.from_numpy(np.asarray(trained_betas))
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), stored in reverse timestep order with a trailing 0.0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # History of ODE derivatives consumed by the multistep update in `step`.
        self.derivatives = []

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        The coefficient is the integral, from `self.sigmas[t]` to `self.sigmas[t + 1]`, of the Lagrange basis
        polynomial built on the nodes `self.sigmas[t], self.sigmas[t - 1], ..., self.sigmas[t - order + 1]` that
        equals 1 at node `self.sigmas[t - current_order]` and 0 at every other node.

        Args:
            order (`int`): number of interpolation nodes (previous derivatives) used by the multistep formula.
            t (`int`): index of the current step into `self.sigmas`.
            current_order (`int`): which node's basis polynomial to integrate; `0` is the current step and `k` is
                the step `k` positions earlier. Must be in `range(order)`.

        Returns:
            `float`: the integrated coefficient.
        """
        # Read the nodes once as Python floats. Evaluating the integrand on tensors would run several small torch
        # ops (and, for sigmas on an accelerator device, a host sync) at every quadrature point.
        nodes = [float(self.sigmas[t - k]) for k in range(order)]
        anchor = nodes[current_order]

        def lms_derivative(tau):
            prod = 1.0
            for k, node in enumerate(nodes):
                if k == current_order:
                    continue
                prod *= (tau - node) / (anchor - node)
            return prod

        integrated_coeff = integrate.quad(
            lms_derivative, float(self.sigmas[t]), float(self.sigmas[t + 1]), epsrel=1e-4
        )[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Also clears the derivative history, so each sampling run starts from a first-order step.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        # Interpolate the training-time sigmas at the (possibly fractional) inference timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device)

        # Derivatives from a previous run must not leak into this one.
        self.derivatives = []

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`):
                position of the current step in the diffusion chain. It is used as an index into `self.sigmas`,
                so callers presumably pass the loop index rather than the timestep value.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order: maximum order of the linear multistep method (number of past derivatives combined).
            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.
        """
        sigma = self.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients.
        # The effective order may not exceed the number of stored derivatives, which matters when sampling starts
        # part-way through the schedule. It also may not reach back past the first sigma (`t - k >= 0`).
        order = min(timestep + 1, order, len(self.derivatives))
        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path (newest derivative pairs with coefficient 0)
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        if not return_dict:
            return (prev_sample,)

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Noise `original_samples` to the sigma levels selected by `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): the clean samples.
            noise (`torch.FloatTensor`): noise scaled by sigma and added to `original_samples`.
            timesteps (`torch.IntTensor`):
                indices into `self.sigmas`, presumably one per batch element. A plain int is also accepted.

        Returns:
            `torch.FloatTensor`: `original_samples + noise * sigma`.
        """
        sigmas = self.sigmas.to(original_samples.device)
        timesteps = torch.as_tensor(timesteps, device=original_samples.device)

        # Gather one sigma per index, then append trailing singleton dims so it broadcasts over the sample dims.
        sigma = sigmas[timesteps].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps