# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import torch

import tqdm
from diffusers import (
    ClassifierFreeGuidanceScheduler,
    CLIPTextModel,
    DiffusionPipeline,
    GlideDDIMScheduler,
    GLIDESuperResUNetModel,
    GLIDETextToImageUNetModel,
)
from transformers import GPT2Tokenizer


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet,
            text_noise_scheduler=text_noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            upscale_unet=upscale_unet,
            upscale_noise_scheduler=upscale_noise_scheduler,
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        B, C = x.shape[:2]
        assert t.shape == (B,)
        if transformer_out is None:
            # super-res model
            model_output = model(x, t, low_res)
        else:
            # text2image model
            model_output = model(x, t, transformer_out)

        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
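        # Learned-variance interpolation, as in "Improved Denoising Diffusion
        # Probabilistic Models" (Nichol & Dhariwal, 2021).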
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)

        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        assert x_t.shape == eps.shape
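        # Invert the forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps:
        #   x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps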
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        # Create a classifier-free guidance sampling function
        guidance_scale = 3.0

        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
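            # Classifier-free guidance: mix the conditional and unconditional
            # noise predictions, eps = eps_uncond + scale * (eps_cond - eps_uncond).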
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = self.text_noise_scheduler.sample_noise(
            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
        # an empty prompt is tokenized alongside the real one to provide the
        # unconditional branch for classifier-free guidance
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
        num_timesteps = len(self.text_noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
            )
            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
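            # Ancestral sampling step: x_{t-1} = mean + exp(0.5 * log_variance) * z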
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        # 4. Run the upscaling step
        batch_size = 1
        image = image[:1]
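        # Round-trip the 64x64 sample through 8-bit quantization; presumably this
        # matches how the low-res conditioning images were prepared during training.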
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
        eta = 0.0  # DDIM eta; currently unused in the sampling loop below

        # Tune this parameter to control the sharpness of 256x256 images.
        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
        upsample_temp = 0.997

        image = (
            self.upscale_noise_scheduler.sample_noise(
                (batch_size, 3, 256, 256), device=torch_device, generator=generator
            )
            * upsample_temp
        )

        num_timesteps = len(self.upscale_noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
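            # image_coeff and clipped_coeff are the DDPM posterior mean coefficients
            # (Ho et al., 2020, Eq. 7) for x_t and the clipped x_0 prediction, respectively.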
            image_coeff = (
                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * self.upscale_noise_scheduler.get_beta(t)
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
            model_output = self.upscale_unet(image, time_input, low_res)
            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = self.upscale_noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            # (sample_variance is expected to return the noise term already scaled
            # by the posterior standard deviation, so adding it draws the sample)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        image = image[0].permute(1, 2, 0)  # (C, H, W) -> (H, W, C)

        return image
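

# Minimal usage sketch (illustrative only; the checkpoint name "fusing/glide-base"
# is an assumption, not a confirmed model id):
#
#     pipeline = GLIDE.from_pretrained("fusing/glide-base")
#     image = pipeline("an oil painting of a corgi")
#     # `image` is a (256, 256, 3) tensor with values roughly in [-1, 1]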