scheduling_sde_ve.py 13.1 KB
Newer Older
1
# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
15
16
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

17
import math
18
19
from dataclasses import dataclass
from typing import Optional, Tuple, Union
20
21
22

import torch

23
from ..configuration_utils import ConfigMixin, register_to_config
Dhruv Nair's avatar
Dhruv Nair committed
24
25
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
26
27
28
29
30
31
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Mean averaged `prev_sample` over previous timesteps.
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor
44
45


Patrick von Platen's avatar
Patrick von Platen committed
46
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 2000):
            The number of diffusion steps to train the model.
        snr (`float`, defaults to 0.15):
            A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
        sigma_min (`float`, defaults to 0.01):
            The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror
            the distribution of the data.
        sigma_max (`float`, defaults to 1348.0):
            The maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`, defaults to 1e-5):
            The end value of sampling where timesteps decrease progressively from 1 to epsilon.
        correct_steps (`int`, defaults to 1):
            The number of correction steps performed on a produced sample.
    """

    # First-order scheduler: one model evaluation per predictor step.
    order = 1

71
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # `register_to_config` records all of the signature arguments into `self.config`.
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values
        # Must be None before `set_sigmas` below so that it triggers `set_timesteps`.
        self.timesteps = None

        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
88

89
90
91
92
93
94
    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # The VE-SDE formulation feeds the sample to the model unscaled; this hook
        # exists only so all schedulers expose the same interface.
        return sample

106
107
108
    def set_timesteps(
        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
    ):
109
        """
110
        Sets the continuous timesteps used for the diffusion chain (to be run before inference).
111
112
113

        Args:
            num_inference_steps (`int`):
114
115
116
117
118
                The number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, *optional*):
                The final timestep value (overrides value given during scheduler instantiation).
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
119
120

        """
121
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
122

123
        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
Patrick von Platen's avatar
Patrick von Platen committed
124

125
126
127
    def set_sigmas(
        self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
    ):
128
        """
129
130
        Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
        of the `drift` and `diffusion` components of the sample update.
131
132
133

        Args:
            num_inference_steps (`int`):
134
                The number of diffusion steps used when generating samples with a pre-trained model.
135
            sigma_min (`float`, optional):
136
                The initial noise scale value (overrides value given during scheduler instantiation).
137
            sigma_max (`float`, optional):
138
                The final noise scale value (overrides value given during scheduler instantiation).
139
            sampling_eps (`float`, optional):
140
                The final timestep value (overrides value given during scheduler instantiation).
141
142

        """
143
144
145
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
Patrick von Platen's avatar
Patrick von Platen committed
146
        if self.timesteps is None:
147
            self.set_timesteps(num_inference_steps, sampling_eps)
Patrick von Platen's avatar
Patrick von Platen committed
148

149
150
151
        self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
Nathan Lambert's avatar
Nathan Lambert committed
152
153

    def get_adjacent_sigma(self, timesteps, t):
154
155
156
157
158
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )
Nathan Lambert's avatar
Nathan Lambert committed
159

160
161
    def step_pred(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
                is returned where the first element is the sample tensor.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # Broadcast the (continuous) timestep over the batch, then map it to a discrete
        # index into `self.discrete_sigmas` — assumes timestep lies in [0, 1]; TODO confirm with callers.
        timestep = timestep * torch.ones(
            sample.shape[0], device=sample.device
        )  # torch.repeat_interleave(timestep, sample.shape[0])
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        # Diffusion coefficient between adjacent discrete noise levels of the VE SDE.
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        # Reshape `diffusion` to (batch, 1, 1, ...) so it broadcasts against `sample`.
        diffusion = diffusion.flatten()
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample fresh noise for the diffusion term of the reverse SDE
        noise = randn_tensor(
            sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
        )
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
227
228
229

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
        making the prediction for the previous timestep.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
                is returned where the first element is the sample tensor.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # Draw fresh noise for the Langevin correction step.
        z = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

        # Step size from the signal-to-noise ratio and the relative magnitudes of the
        # score estimate (model_output) and the fresh noise.
        score_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        z_norm = torch.norm(z.reshape(z.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * z_norm / score_norm) ** 2 * 2
        # self.repeat_scalar(step_size, sample.shape[0])
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)

        # Reshape the per-sample step size to (batch, 1, 1, ...) so it broadcasts
        # against `sample`, then apply the score term and the noise term.
        step_size = step_size.flatten().reshape(sample.shape[0], *((1,) * (len(sample.shape) - 1)))
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * z

        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)

        return (prev_sample,)
Nathan Lambert's avatar
Nathan Lambert committed
282

283
284
285
286
287
288
289
290
291
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        timesteps = timesteps.to(original_samples.device)
        sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
Uranus's avatar
Uranus committed
292
293
294
295
296
        noise = (
            noise * sigmas[:, None, None, None]
            if noise is not None
            else torch.randn_like(original_samples) * sigmas[:, None, None, None]
        )
297
298
299
        noisy_samples = noise + original_samples
        return noisy_samples

Nathan Lambert's avatar
Nathan Lambert committed
300
301
    def __len__(self):
        return self.config.num_train_timesteps