# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
31
    Output class for the scheduler's `step` function output.
32
33
34

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
35
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
36
37
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
            Mean averaged `prev_sample` over previous timesteps.
39
40
41
42
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor
43
44


Patrick von Platen's avatar
Patrick von Platen committed
45
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 2000):
            The number of diffusion steps to train the model.
        snr (`float`, defaults to 0.15):
            A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
        sigma_min (`float`, defaults to 0.01):
            The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror
            the distribution of the data.
        sigma_max (`float`, defaults to 1348.0):
            The maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`, defaults to 1e-5):
            The end value of sampling where timesteps decrease progressively from 1 to epsilon.
        correct_steps (`int`, defaults to 1):
            The number of correction steps performed on a produced sample.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values (populated by `set_timesteps` / `set_sigmas`)
        self.timesteps = None

        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # The VE-SDE formulation requires no input scaling; return the sample unchanged.
        return sample

    def set_timesteps(
        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
    ):
        """
        Sets the continuous timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, *optional*):
                The final timestep value (overrides value given during scheduler instantiation).
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        # Continuous timesteps decreasing from 1 down to `sampling_eps`.
        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)

    def set_sigmas(
        self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
    ):
        """
        Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
        of the `drift` and `diffusion` components of the sample update.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                The initial noise scale value (overrides value given during scheduler instantiation).
            sigma_max (`float`, optional):
                The final noise scale value (overrides value given during scheduler instantiation).
            sampling_eps (`float`, optional):
                The final timestep value (overrides value given during scheduler instantiation).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        # Geometric sigma schedule: `discrete_sigmas` is log-linearly spaced between sigma_min and sigma_max;
        # `sigmas` evaluates the same geometric interpolation at the continuous timesteps.
        # NOTE: a previous (dead) assignment `sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)`
        # was removed here — it was immediately overwritten and never observable.
        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])

    def get_adjacent_sigma(self, timesteps, t):
        """
        Return the sigma of the previous discrete timestep (`timesteps - 1`), or zero where `timesteps == 0`
        (there is no earlier sigma at the first step).
        """
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def step_pred(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
                is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called before stepping.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # Broadcast the scalar timestep over the batch, then map the continuous value to a discrete index.
        timestep = timestep * torch.ones(
            sample.shape[0], device=sample.device
        )  # torch.repeat_interleave(timestep, sample.shape[0])
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        # Reshape `diffusion` so it broadcasts against `sample` (append trailing singleton dims).
        diffusion = diffusion.flatten()
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output

        #  equation 6: sample noise for the diffusion term of
        noise = randn_tensor(
            sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
        )
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
        making the prediction for the previous timestep.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called before stepping.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
        # self.repeat_scalar(step_size, sample.shape[0])

        # compute corrected sample: model_output term and noise term
        # Reshape `step_size` so it broadcasts against `sample` (append trailing singleton dims).
        step_size = step_size.flatten()
        while len(step_size.shape) < len(sample.shape):
            step_size = step_size.unsqueeze(-1)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add noise to `original_samples`, scaled by the discrete sigma at each of `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`):
                The clean samples to be noised.
            noise (`torch.FloatTensor`):
                The noise to add; if `None`, standard Gaussian noise is drawn instead.
            timesteps (`torch.FloatTensor`):
                Discrete timestep indices into `discrete_sigmas`, one per sample.

        Returns:
            `torch.FloatTensor`:
                The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        timesteps = timesteps.to(original_samples.device)
        sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
        # NOTE: the `[:, None, None, None]` broadcast assumes 4-D (image-shaped) samples.
        noise = (
            noise * sigmas[:, None, None, None]
            if noise is not None
            else torch.randn_like(original_samples) * sigmas[:, None, None, None]
        )
        noisy_samples = noise + original_samples
        return noisy_samples

    def __len__(self):
        # Length of the scheduler is the number of training timesteps.
        return self.config.num_train_timesteps