# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit

import math
import warnings

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance preserving stochastic differential equation (SDE) scheduler.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    UNDER CONSTRUCTION

    """

    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
        """
        Args:
            num_train_timesteps (`int`): number of diffusion steps used to train the model.
            beta_min (`float`): minimum noise schedule coefficient of the VP-SDE.
            beta_max (`float`): maximum noise schedule coefficient of the VP-SDE.
            sampling_eps (`float`): smallest continuous time value sampled to, avoids the t=0 singularity
                where the marginal std is zero.
        """
        if "tensor_format" in kwargs:
            warnings.warn(
                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
                "If you're running your code in PyTorch, you can safely remove this argument.",
                DeprecationWarning,
            )

        # unused placeholders kept for interface compatibility with other schedulers
        self.sigmas = None
        self.discrete_sigmas = None
        # populated by `set_timesteps`; `step_pred` refuses to run until it is set
        self.timesteps = None

    def set_timesteps(self, num_inference_steps, device=None):
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`): the number of diffusion steps used when generating samples.
            device (`str` or `torch.device`, *optional*): the device to which the timesteps should be moved.
                Defaults to `None` (current default device), which preserves the previous behavior.
        """
        # timesteps run from 1 down to sampling_eps (continuous time, reverse direction)
        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)

    def step_pred(self, score, x, t, generator=None):
        """
        Predict the sample at the previous timestep by reversing the SDE (one Euler-Maruyama step of the
        reverse-time VP-SDE).

        Args:
            score (`torch.FloatTensor`): output of the score model for sample `x` at time `t`.
            x (`torch.FloatTensor`): current instance of the sample being diffused.
            t (`torch.FloatTensor`): current continuous timestep(s) in the chain.
                NOTE(review): assumed to be a tensor — `.flatten()` is called on expressions derived from it.
            generator (`torch.Generator`, *optional*): random number generator for the added noise.

        Returns:
            `tuple(torch.FloatTensor, torch.FloatTensor)`: the noisy updated sample `x` and the noise-free
            mean `x_mean`.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # TODO(Patrick) better comments + non-PyTorch
        # postprocess model score: rescale the model output by the marginal std of the VP-SDE at time t,
        # std = sqrt(1 - exp(2 * log_mean_coeff)), so that `score` approximates the true score function
        log_mean_coeff = (
            -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
        )
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        # broadcast std across the trailing (non-batch) dimensions of `score`
        std = std.flatten()
        while len(std.shape) < len(score.shape):
            std = std.unsqueeze(-1)
        score = -score / std

        # compute the (negative) step size of the reverse-time integration
        dt = -1.0 / len(self.timesteps)

        # linear beta schedule evaluated at continuous time t
        beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
        beta_t = beta_t.flatten()
        while len(beta_t.shape) < len(x.shape):
            beta_t = beta_t.unsqueeze(-1)
        # forward-SDE drift f(x, t) = -0.5 * beta(t) * x
        drift = -0.5 * beta_t * x

        # reverse-SDE drift: f(x, t) - g(t)^2 * score
        diffusion = torch.sqrt(beta_t)
        drift = drift - diffusion**2 * score
        x_mean = x + drift * dt

        # add noise scaled by g(t) * sqrt(|dt|) (Euler-Maruyama diffusion term);
        # noise is drawn on the generator's device then moved to x.device
        noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
        x = x_mean + diffusion * math.sqrt(-dt) * noise

        return x, x_mean

    def __len__(self):
        # number of training timesteps, as stored by `register_to_config`
        return self.config.num_train_timesteps