scheduling_sde_ve.py 11.9 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
15
16
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

17
import warnings
18
19
from dataclasses import dataclass
from typing import Optional, Tuple, Union
20
21
22
23

import numpy as np
import torch

24
from ..configuration_utils import ConfigMixin, register_to_config
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the ScoreSdeVeScheduler's `step_pred` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The deterministic part of the same update: `prev_sample` before the freshly sampled, scaled noise term
            is added (`prev_sample = prev_sample_mean + diffusion * noise`).
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor
44
45


Patrick von Platen's avatar
Patrick von Platen committed
46
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    Args:
        num_train_timesteps (`int`):
            number of noise scales / timesteps built at construction time; also the value returned by `len()`.
        snr (`float`):
            coefficient that scales the correction step size computed in `step_correct`.
        sigma_min (`float`):
                smallest noise scale of the sigma sequence used in the sampling procedure. The minimum sigma should
                mirror the distribution of the data.
        sigma_max (`float`):
            largest noise scale; the discrete sigmas are geometrically spaced between `sigma_min` and `sigma_max`.
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample. Stored in the config;
            this class does not read it itself (presumably consumed by the sampling loop — verify in the caller).
        tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=2000,
        snr=0.15,
        sigma_min=0.01,
        sigma_max=1348,
        sampling_eps=1e-5,
        correct_steps=1,
        tensor_format="pt",
    ):
        # setable values
        self.timesteps = None

        # `self.tensor_format` is not assigned yet, so the helpers fall back to their `getattr(..., "pt")` default
        # and build torch tensors here.
        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

        self.tensor_format = tensor_format
        # `set_format` comes from `SchedulerMixin`; presumably it converts the arrays built above to
        # `tensor_format` — verify in scheduling_utils.
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps, sampling_eps=None):
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        The timesteps are `num_inference_steps` evenly spaced values decreasing from 1 to `sampling_eps`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
        elif tensor_format == "pt":
            self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
        else:
            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        Two tables are built:
          - `self.discrete_sigmas`: `num_inference_steps` values geometrically spaced from `sigma_min` to
            `sigma_max`; indexed by `step_pred`.
          - `self.sigmas`: `sigma_min * (sigma_max / sigma_min) ** t` evaluated at every entry of `self.timesteps`.

        If `self.timesteps` is not set yet, `set_timesteps(num_inference_steps, sampling_eps)` is called first.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                smallest noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional): largest noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
            self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
        elif tensor_format == "pt":
            self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
            self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
        else:
            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def get_adjacent_sigma(self, timesteps, t):
        """
        Return the discrete sigma of the previous index for every entry of `timesteps`, or 0 where the index is 0.

        Args:
            timesteps: integer indices into `self.discrete_sigmas`. For "pt", they must live on a device that may
                index `self.discrete_sigmas` (its own device or the CPU).
            t: array/tensor used only as a shape/dtype template for the zeros returned where `timesteps == 0`.

        Returns:
            Array/tensor of adjacent sigmas. For "pt" it lives on `timesteps.device`.

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        tensor_format = getattr(self, "tensor_format", "pt")
        # For index 0, `timesteps - 1` wraps to -1 (the last sigma); that value is discarded by the `where`.
        if tensor_format == "np":
            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
        elif tensor_format == "pt":
            # Build every operand on `timesteps.device` so `torch.where` never mixes devices.
            return torch.where(
                timesteps == 0,
                torch.zeros_like(t, device=timesteps.device),
                self.discrete_sigmas[timesteps - 1].to(timesteps.device),
            )

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def set_seed(self, seed):
        """
        Deprecated: seed the global RNG of the current tensor format (numpy or torch).

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        warnings.warn(
            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
            " generator instead.",
            DeprecationWarning,
        )
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            np.random.seed(seed)
        elif tensor_format == "pt":
            torch.manual_seed(seed)
        else:
            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def step_pred(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        NOTE: this method builds torch tensors from `sample.device`, so it currently expects torch inputs.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current timestep in the diffusion chain. It is treated as a continuous value in
                `[0, 1]` (e.g. an entry of `self.timesteps`) and scaled by `len(self.timesteps) - 1` to index the
                discrete sigma table.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process. The first dimension is the batch;
                any number of trailing dimensions is supported.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SdeVeOutput class

        Returns:
            `SdeVeOutput`, or the tuple `(prev_sample, prev_sample_mean)` when `return_dict` is False.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # One continuous timestep per batch element.
        timestep = timestep * torch.ones(sample.shape[0], device=sample.device)
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # Index the sigma table on its own device (usually the CPU): indexing a CPU tensor with indices that live
        # on another device is rejected by PyTorch.
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = self.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # Reshape the per-sample coefficient to (batch, 1, ..., 1) so it broadcasts against samples of any rank.
        diffusion = diffusion.flatten()
        while diffusion.dim() < sample.dim():
            diffusion = diffusion.unsqueeze(-1)

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term
        noise = self.randn_like(sample, generator=generator)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        sample: Union[torch.FloatTensor, np.ndarray],
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        The update is `sample + step_size * model_output + sqrt(2 * step_size) * noise`, where
        `step_size = 2 * (snr * ||noise|| / ||model_output||) ** 2` and the norms come from `self.norm`.

        NOTE: this method builds torch tensors from `sample.device`, so it currently expects torch inputs.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process. The first dimension is the batch;
                any number of trailing dimensions is supported.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            `SchedulerOutput`, or the tuple `(prev_sample,)` when `return_dict` is False.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = self.randn_like(sample, generator=generator)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = self.norm(model_output)
        noise_norm = self.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(sample.shape[0], device=sample.device)

        # Reshape the per-sample step size to (batch, 1, ..., 1) so it broadcasts against samples of any rank.
        step_size = step_size.flatten()
        while step_size.dim() < sample.dim():
            step_size = step_size.unsqueeze(-1)

        # compute corrected sample: model_output term and noise term
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps