# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

17
import math
18
19
from dataclasses import dataclass
from typing import Optional, Tuple, Union
20
21
22

import torch

23
from ..configuration_utils import ConfigMixin, register_to_config
Anton Lozhkov's avatar
Anton Lozhkov committed
24
from ..utils import BaseOutput
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Mean of the predicted previous sample, i.e. `prev_sample` before the random diffusion noise term is
            added (`prev_sample = prev_sample_mean + diffusion * noise`).
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor


class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model. It is also the number of
            timesteps/sigmas computed at construction time, before `set_timesteps`/`set_sigmas` are called.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise in
            the correction step (`step_correct`).
        sigma_min (`float`):
            minimum noise scale of the sigma sequence, reached at the end of sampling (t -> `sampling_eps`). The
            minimum sigma should mirror the distribution of the data.
        sigma_max (`float`): maximum noise scale of the sigma sequence, used at the start of sampling (t = 1). It is
            also the standard deviation of the initial noise (`init_noise_sigma`).
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values
        self.timesteps = None

        # also populates `self.timesteps` (with `num_train_timesteps` entries) since it is still None here
        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        sampling_eps: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        The timesteps decrease linearly from 1 to `sampling_eps` (both ends included).

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
            device (`str` or `torch.device`, optional): device on which the timesteps tensor is created.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)

    def set_sigmas(
        self,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ):
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        Two tables are computed:
            - `discrete_sigmas`: `num_inference_steps` values spaced geometrically from `sigma_min` to `sigma_max`,
              indexed by `step_pred`.
            - `sigmas`: `sigma_min * (sigma_max / sigma_min) ** t` for every continuous timestep `t` in
              `self.timesteps`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                minimum noise scale, reached at the end of sampling (overrides value given at Scheduler
                instantiation).
            sigma_max (`float`, optional):
                maximum noise scale, used at the start of sampling (overrides value given at Scheduler
                instantiation).
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        # geometric spacing between sigma_min and sigma_max (linear in log space)
        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        # vectorized sigma(t) = sigma_min * (sigma_max / sigma_min) ** t, evaluated at every continuous timestep
        self.sigmas = sigma_min * (sigma_max / sigma_min) ** self.timesteps

    def get_adjacent_sigma(self, timesteps, t):
        """
        Returns the discrete sigma one index below each entry of `timesteps`, or zero where the index is 0.

        Args:
            timesteps (`torch.LongTensor`): indices into `self.discrete_sigmas`.
            t (`torch.Tensor`): tensor whose shape/dtype is used for the zero fill (moved to `timesteps.device`).
        """
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            # for index 0 this gathers discrete_sigmas[-1], but that value is discarded by the `where`
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def step_pred(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs, which are interpreted as the score `grad_x log p_t(x)`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float` or `torch.Tensor`): current continuous timestep in `[sampling_eps, 1]` (e.g. an
                element of `self.timesteps`). It is mapped to an index into `self.discrete_sigmas`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SdeVeOutput class

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
            `return_dict` is True, otherwise a `tuple` `(prev_sample, prev_sample_mean)`.

        Raises:
            ValueError: if `self.timesteps` has not been set.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # one copy of the continuous timestep per batch element
        timestep = timestep * torch.ones(sample.shape[0], device=sample.device)
        # Map t in [sampling_eps, 1] to an index into `discrete_sigmas`. Scale by the size of that table (rather than
        # of `self.timesteps`) so the index stays in range even when the two were set with different step counts.
        timesteps = (timestep * (len(self.discrete_sigmas) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        # append trailing singleton dims so the per-batch diffusion broadcasts over the sample
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term of the reverse SDE
        noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        The update is Langevin-style: `sample + step_size * model_output + sqrt(2 * step_size) * noise`, where
        `step_size` is derived from `self.config.snr` and the norms of the model output and the sampled noise.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple` whose only
            element is the corrected sample tensor.

        Raises:
            ValueError: if `self.timesteps` has not been set.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # (not applied here: the batch-mean norm of the sampled noise is used instead)
        # sample noise for correction
        noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        # one copy of the (scalar) step size per batch element
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        # append trailing singleton dims so the per-batch step size broadcasts over the sample
        while len(step_size.shape) < len(sample.shape):
            step_size = step_size.unsqueeze(-1)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        # number of training diffusion steps from the config, not the number of inference steps currently set
        return self.config.num_train_timesteps