scheduling_sde_ve.py 3.19 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
15
16
17
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

# TODO(Patrick, Anton, Suraj) - make the scheduler framework-independent and clean it up a bit
18
19
20
21
22
23
24
25

import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


Patrick von Platen's avatar
Patrick von Platen committed
26
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """Variance-exploding (VE) SDE scheduler providing predictor and corrector sampling steps.

    The continuous noise level follows the geometric schedule
    ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** t`` for ``t`` running from 1 down to
    ``sampling_eps``; ``discrete_sigmas`` is a log-spaced table of the same range used by the
    predictor step.

    Args:
        snr: target signal-to-noise ratio used to size the corrector (Langevin) step.
        sigma_min: smallest noise scale of the schedule.
        sigma_max: largest noise scale of the schedule.
        sampling_eps: last (smallest) continuous timestep produced by ``set_timesteps``.
        tensor_format: accepted for interface compatibility only; it is not stored in the config
            and the implementation is PyTorch-only.
    """

    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
        super().__init__()
        self.register_to_config(
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            sampling_eps=sampling_eps,
        )

        # Populated by ``set_timesteps`` / ``set_sigmas`` before sampling.
        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

    @staticmethod
    def _expand_to(values, x):
        """Reshape a per-sample ``(batch,)`` tensor so it broadcasts against ``x`` of any rank."""
        return values.reshape(-1, *([1] * (x.dim() - 1)))

    def set_timesteps(self, num_inference_steps):
        """Create ``num_inference_steps`` continuous timesteps, linearly spaced from 1 to ``sampling_eps``."""
        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

    def set_sigmas(self, num_inference_steps):
        """Build the noise-level tables used during sampling.

        Sets ``discrete_sigmas`` (log-spaced from ``sigma_min`` to ``sigma_max``) and ``sigmas``
        (the geometric schedule evaluated at each timestep). Timesteps are (re)created when they
        are missing or were built for a different number of steps.

        Args:
            num_inference_steps: number of sampling steps.
        """
        # ``step_pred`` maps t onto an index via ``len(self.timesteps)``, so the timestep table
        # must have the same length as ``discrete_sigmas`` or the lookup goes out of range.
        if self.timesteps is None or len(self.timesteps) != num_inference_steps:
            self.set_timesteps(num_inference_steps)

        sigma_min = self.config.sigma_min
        sigma_max = self.config.sigma_max

        self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
        self.sigmas = sigma_min * (sigma_max / sigma_min) ** self.timesteps

    def step_pred(self, result, x, t):
        """Predictor (reverse-diffusion) step.

        Args:
            result: model output for ``x``, used as the score estimate; must broadcast against ``x``.
            x: current sample, batch along dim 0 (any rank).
            t: continuous timestep in ``[sampling_eps, 1]``, scalar or per-sample.

        Returns:
            ``(x, x_mean)``: the next noisy sample and its noise-free mean.

        Raises:
            ValueError: if ``set_sigmas`` has not been called yet.
        """
        # TODO(Patrick) better comments + non-PyTorch
        if self.discrete_sigmas is None or self.timesteps is None:
            raise ValueError("`set_sigmas` must be called before `step_pred`.")

        t = t * torch.ones(x.shape[0], device=x.device)
        # Map continuous t onto an index into the discrete sigma table.
        timestep = (t * (len(self.timesteps) - 1)).long()

        # Move the table once so indexing with a device-resident index tensor works.
        discrete_sigmas = self.discrete_sigmas.to(x.device)
        sigma = discrete_sigmas[timestep]
        # Noise level of the neighbouring (smaller) step, zero at index 0. ``timestep - 1``
        # wraps to -1 there, but torch.where discards that value.
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), discrete_sigmas[timestep - 1])

        G = self._expand_to(torch.sqrt(sigma**2 - adjacent_sigma**2), x)

        # Equivalent to the original ``f = -G**2 * result; x_mean = x - f``.
        x_mean = x + G**2 * result
        z = torch.randn_like(x)
        x = x_mean + G * z
        return x, x_mean

    def step_correct(self, result, x):
        """Corrector (Langevin) step.

        The step size is chosen so that ``noise_norm / grad_norm`` (batch-averaged L2 norms)
        matches ``config.snr``.

        Args:
            result: model output for ``x``, used as the score estimate; must broadcast against ``x``.
            x: current sample, batch along dim 0 (any rank).

        Returns:
            The corrected sample.
        """
        # TODO(Patrick) better comments + non-PyTorch
        noise = torch.randn_like(x)
        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        # NOTE(review): grad_norm == 0 yields an infinite step size; unchanged from the original.
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = self._expand_to(step_size * torch.ones(x.shape[0], device=x.device), x)

        x_mean = x + step_size * result
        x = x_mean + torch.sqrt(step_size * 2) * noise

        return x