# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math

import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)


class DDPMScheduler(SchedulerMixin, ConfigMixin):
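    """
    Denoising Diffusion Probabilistic Models (DDPM) scheduler, following Ho et al. 2020
    (https://arxiv.org/pdf/2006.11239.pdf). The constructor builds the beta/alpha schedules, `get_variance`
    returns the per-timestep posterior variance, `step` predicts the previous (less noisy) sample, and
    `training_step` adds noise to clean samples for training.

    Rough sampling sketch (illustrative only, not part of the library API): `model` stands in for any callable
    that predicts the noise residual, the scheduler is assumed to use the "pt" tensor format and the default
    "fixed_small" variance type:

        scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
        sample = torch.randn(1, 3, 32, 32)
        for t in reversed(range(len(scheduler))):
            residual = model(sample, t)                       # predicted noise epsilon
            pred_prev = scheduler.step(residual, sample, t)   # predicted mean of x_{t-1}
            noise = torch.randn_like(sample) if t > 0 else 0.0
            sample = pred_prev + scheduler.get_variance(t) ** 0.5 * noise
    """
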
    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        variance_type="fixed_small",
        clip_sample=True,
        tensor_format="np",
    ):
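        """
        Configure the noise schedule.

        Args:
            timesteps: number of diffusion steps used to train/sample the model.
            beta_start: starting beta value of the schedule.
            beta_end: final beta value of the schedule.
            beta_schedule: "linear" or "squaredcos_cap_v2" (GLIDE-style cosine schedule).
            trained_betas: optional pre-computed betas that override `beta_schedule` when given.
            timestep_values: optional custom timestep values; stored in the config, not used directly in this file.
            variance_type: one of "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log".
            clip_sample: whether to clip the predicted original sample to [-1, 1] in `step`.
            tensor_format: "np" for NumPy arrays or "pt" for PyTorch tensors.
        """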
        super().__init__()
        self.register_to_config(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_sample=clip_sample,
        )

        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
        elif beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
            self.betas = betas_for_alpha_bar(timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.one = np.array(1.0)

        self.set_format(tensor_format=tensor_format)

    def get_variance(self, t, variance_type=None):
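        """
        Compute the variance of the reverse-process posterior q(x_{t-1} | x_t, x_0) at timestep `t`, following
        formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf, or one of the fixed variants selected via
        `variance_type` (the `*_log` variants return the log-variance).
        """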
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

        # For t > 0, compute predicted variance βt (see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - probably added for training stability
        if variance_type == "fixed_small":
            variance = self.clip(variance, min_value=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = self.log(self.clip(variance, min_value=1e-20))
        elif variance_type == "fixed_large":
            variance = self.betas[t]
        elif variance_type == "fixed_large_log":
            # GLIDE max_log
            variance = self.log(self.betas[t])

        return variance

    def step(self, residual, sample, t, predict_epsilon=True):
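        """
        Predict the previous sample x_{t-1} from the model output at timestep `t`.

        `residual` is interpreted as the predicted noise epsilon when `predict_epsilon=True`, otherwise as the
        predicted original sample x_0. Returns the predicted mean of x_{t-1} (formula (7) of the DDPM paper);
        sampling noise scaled by `get_variance(t)` has to be added by the caller.
        """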
        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = residual

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        return pred_prev_sample

    def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
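        """
        Diffuse `original_samples` to the given `timesteps` by mixing them with `noise` according to the forward
        process q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
        """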
        if timesteps.dim() != 1:
            raise ValueError("`timesteps` must be a 1D tensor")

        device = original_samples.device
        batch_size = original_samples.shape[0]
        timesteps = timesteps.reshape(batch_size, 1, 1, 1)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
        return noisy_samples

    def __len__(self):
        return self.config.timesteps