scheduling_sde_ve.py 12.4 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

17
import math
18
19
from dataclasses import dataclass
from typing import Optional, Tuple, Union
20
21
22

import torch

23
from ..configuration_utils import ConfigMixin, register_to_config
24
from ..utils import BaseOutput, deprecate
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Noise-free part of the update: the predicted previous sample *before* the random diffusion noise is
            added, i.e. `prev_sample = prev_sample_mean + diffusion * noise`. Commonly used as the final output
            of the sampling loop, where no further noise should be injected.
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor


class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model. Also used as the number of
            sigmas computed at construction time and as the value of `len(scheduler)`.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise.
            Used by `step_correct` to derive the Langevin step size.
        sigma_min (`float`):
            smallest noise scale of the sigma sequence, reached at the end of sampling. The minimum sigma should
            mirror the distribution of the data.
        sigma_max (`float`): largest noise scale of the sigma sequence, used at the start of sampling. It is also
            exposed as `init_noise_sigma`, the standard deviation of the initial noise distribution.
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample. It is only stored in the
            config here; this class does not read it (presumably consumed by the calling pipeline).
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
        **kwargs,
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values
        self.timesteps = None

        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample (this scheduler returns `sample` unchanged)
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        sampling_eps: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        The timesteps are `num_inference_steps` values spaced linearly from 1 down to `sampling_eps`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
            device (`str` or `torch.device`, optional): device on which the timesteps tensor is created.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)

    def set_sigmas(
        self,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ):
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        If `self.timesteps` is not set yet, `set_timesteps(num_inference_steps, sampling_eps)` is called first;
        otherwise the existing timesteps are reused and `sampling_eps` has no effect.

        Two sequences are computed:
            - `self.discrete_sigmas`: `num_inference_steps` values spaced geometrically (ascending) from
              `sigma_min` to `sigma_max`; `step_pred` indexes into it.
            - `self.sigmas`: `sigma_min * (sigma_max / sigma_min) ** t` for every `t` in `self.timesteps`. Since
              the timesteps run from 1 down to `sampling_eps`, this starts at `sigma_max` and decreases.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                smallest noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional):
                largest noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        # Vectorized form of the per-timestep formula; an earlier assignment that used the wrong exponent
        # (`timesteps / sampling_eps`) and was immediately overwritten has been removed.
        self.sigmas = sigma_min * (sigma_max / sigma_min) ** self.timesteps

    def get_adjacent_sigma(self, timesteps, t):
        """
        Return, for each discrete index in `timesteps`, the sigma one position lower in `self.discrete_sigmas`, or
        zero where the index is 0.

        Args:
            timesteps (`torch.LongTensor`): integer indices into `self.discrete_sigmas`.
            t (`torch.Tensor`): only used to build the zero tensor via `torch.zeros_like`; moved to the device of
                `timesteps` first.
        """
        # `timesteps - 1` is -1 where `timesteps == 0`; that (wrapped) lookup is discarded by `torch.where`.
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def set_seed(self, seed):
        """
        Deprecated: seed the global torch RNG with `torch.manual_seed(seed)`. Pass a `generator` to
        `step_pred` / `step_correct` instead.
        """
        deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")
        torch.manual_seed(seed)

    def step_pred(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float` or `torch.Tensor`): current continuous timestep, expected in `[0, 1]` (e.g. an
                element of `self.timesteps`). It is scaled by `len(self.timesteps) - 1` and truncated to an integer
                index into `self.discrete_sigmas`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): if True, return an `SdeVeOutput`; otherwise return a plain tuple.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
            `return_dict` is True, otherwise a `tuple` `(prev_sample, prev_sample_mean)`.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # one copy of the (continuous) timestep per batch element
        timestep = timestep * torch.ones(sample.shape[0], device=sample.device)
        # map the continuous timestep in [0, 1] to an integer index into `self.discrete_sigmas`
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        # reshape the per-batch diffusion to (batch, 1, 1, ...) so it broadcasts against `sample`
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term of the reverse SDE
        noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        The step size is `2 * (snr * ||noise|| / ||model_output||) ** 2`, where each norm is computed per batch
        element and then averaged over the batch.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): if True, return a `SchedulerOutput`; otherwise return a plain tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple` whose only
            element is the corrected sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        # reshape the per-batch step size to (batch, 1, 1, ...) so it broadcasts against `sample`
        while len(step_size.shape) < len(sample.shape):
            step_size = step_size.unsqueeze(-1)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps