# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

anton-l's avatar
anton-l committed
17
import math
18
from typing import Union
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
Patrick von Platen committed
20
import numpy as np
21
import torch
Patrick von Platen's avatar
improve  
Patrick von Platen committed
22

23
from ..configuration_utils import ConfigMixin, register_to_config
24
25
26
27
28
from .scheduling_utils import SchedulerMixin


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, alpha_bar=None):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For every step ``i`` the beta is ``min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)`` with
    ``t1 = i / num_diffusion_timesteps`` and ``t2 = (i + 1) / num_diffusion_timesteps``.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
    :param alpha_bar: optional callable that takes a float t in [0, 1] and returns the cumulative product of
        (1-beta) up to that part of the diffusion process. Defaults to the squared-cosine schedule
        ``cos((t + 0.008) / 1.008 * pi / 2) ** 2``.
    :return: a float32 ``np.ndarray`` of shape ``(num_diffusion_timesteps,)``.
    """

    def squaredcos_alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    alpha_bar_fn = squaredcos_alpha_bar if alpha_bar is None else alpha_bar

    betas = [
        # the cap keeps beta away from 1 near t = 1, where alpha_bar approaches 0
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return np.array(betas, dtype=np.float32)


class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPM) scheduler, see https://arxiv.org/pdf/2006.11239.pdf.

    Stores the beta / alpha noise schedule, noises clean samples with ``add_noise`` and performs single
    reverse-diffusion steps with ``step``.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        variance_type="fixed_small",
        clip_sample=True,
        tensor_format="np",
    ):
        """
        :param num_train_timesteps: number of diffusion steps; length of the generated beta schedules and of the
            default ``timesteps``.
        :param beta_start: first beta of the "linear" schedule.
        :param beta_end: last beta of the "linear" schedule.
        :param beta_schedule: "linear" or "squaredcos_cap_v2" (GLIDE cosine schedule). Ignored when
            ``trained_betas`` is given.
        :param trained_betas: optional explicit betas that take precedence over ``beta_schedule``.
        :param timestep_values: not read by this class (presumably kept so it is stored in the config — confirm).
        :param variance_type: default post-processing used by ``_get_variance``: "fixed_small",
            "fixed_small_log", "fixed_large" or "fixed_large_log".
        :param clip_sample: whether ``step`` clips the predicted x_0 to [-1, 1].
        :param tensor_format: forwarded to ``self.set_format``.
        :raises NotImplementedError: for an unknown ``beta_schedule``.
        """
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        # alphas_cumprod[t] = alpha_bar_t, the product of alphas[0..t]
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # stands in for alpha_bar_{t-1} when t == 0
        self.one = np.array(1.0)

        # setable values
        self.num_inference_steps = None
        # default: every training timestep, in descending order
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps):
        """
        Select the descending timesteps used for inference.

        :param num_inference_steps: requested number of denoising steps; capped at ``num_train_timesteps``.
        :raises ValueError: if ``num_inference_steps`` is not positive.
        """
        if num_inference_steps <= 0:
            raise ValueError(f"`num_inference_steps` must be a positive integer, got {num_inference_steps}.")
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Multiply instead of striding with np.arange(0, T, ratio): when T is not divisible by the step count the
        # strided version yields more than `num_inference_steps` timesteps (e.g. 334 for T=1000, 300 steps).
        self.timesteps = (np.arange(0, self.num_inference_steps) * step_ratio)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def _get_variance(self, t, variance_type=None):
        """
        Posterior variance for timestep ``t`` (formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf),
        post-processed according to ``variance_type``.

        :param t: integer timestep used to index ``self.alphas_cumprod`` and ``self.betas``.
        :param variance_type: overrides ``self.config.variance_type`` when given. The "*_log" types return the
            natural log of the variance; an unrecognized value returns the unclipped posterior variance.
        :return: the (log-)variance for ``t``.
        """
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probs added for training stability
        if variance_type == "fixed_small":
            variance = self.clip(variance, min_value=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = self.log(self.clip(variance, min_value=1e-20))
        elif variance_type == "fixed_large":
            variance = self.betas[t]
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = self.log(self.betas[t])

        return variance

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        predict_epsilon=True,
        generator=None,
    ):
        """
        Predict the sample at the previous timestep by reversing the diffusion process for one step.

        :param model_output: model prediction — the noise epsilon if ``predict_epsilon`` else x_0 directly.
        :param timestep: current discrete timestep t.
        :param sample: current noisy sample x_t.
        :param predict_epsilon: whether ``model_output`` is the predicted noise (True) or predicted x_0 (False).
        :param generator: optional generator passed to ``torch.randn`` for the added noise.
        :return: dict with key "prev_sample" holding the predicted x_{t-1}.
        """
        t = timestep

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = model_output

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (skipped at t == 0)
        variance = 0
        if t > 0:
            noise = torch.randn(model_output.shape, generator=generator).to(model_output.device)
            # The "*_log" variance types return log-variances; taking the square root of those (negative) values
            # produced NaN. Use the linear-space counterpart instead, since exp(0.5 * log(v)) == sqrt(v).
            variance_type = self.config.variance_type
            variance_type = {"fixed_small_log": "fixed_small", "fixed_large_log": "fixed_large"}.get(
                variance_type, variance_type
            )
            variance = (self._get_variance(t, variance_type=variance_type) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return {"prev_sample": pred_prev_sample}

    def add_noise(self, original_samples, noise, timesteps):
        """
        Forward-diffuse clean samples: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        :param original_samples: clean samples x_0.
        :param noise: noise of the same shape as ``original_samples``.
        :param timesteps: timestep index or indices into ``self.alphas_cumprod``; the per-timestep coefficients
            are broadcast to ``original_samples`` via ``self.match_shape``.
        :return: the noisy samples x_t.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps from the config."""
        return self.config.num_train_timesteps