# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import torch

import tqdm
from diffusers import (
    ClassifierFreeGuidanceScheduler,
    GlideDDIMScheduler,
    CLIPTextModel,
    DiffusionPipeline,
    GLIDETextToImageUNetModel,
    GLIDESuperResUNetModel,
)
from transformers import GPT2Tokenizer


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


class GLIDE(DiffusionPipeline):
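    """
    Two-stage GLIDE text-to-image pipeline: a 64x64 text-conditional diffusion model sampled with
    classifier-free guidance, followed by a 256x256 super-resolution diffusion model conditioned
    on the low-resolution output.
    """
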
    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet,
            text_noise_scheduler=text_noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            upscale_unet=upscale_unet,
            upscale_noise_scheduler=upscale_noise_scheduler,
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """

        B, C = x.shape[:2]
        assert t.shape == (B,)
        if transformer_out is None:
            # super-res model
            model_output = model(x, t, low_res)
        else:
            # text2image model
            model_output = model(x, t, transformer_out)

        assert model_output.shape == (B, C * 2, *x.shape[2:])
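        # the first C channels are the predicted noise residual (eps); the remaining C channels
        # parameterize the learned variance interpolation (Nichol & Dhariwal, 2021)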
        model_output, model_var_values = torch.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
        # model_var_values lies in [-1, 1] and interpolates between the min and max log-variance.
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)

        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        assert x_t.shape == eps.shape
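        # invert the forward process: x_0 = x_t / sqrt(alphabar_t) - sqrt(1 / alphabar_t - 1) * eps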
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        # Create a classifier-free guidance sampling function
        guidance_scale = 3.0
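        # higher guidance scales increase adherence to the prompt at the cost of sample diversity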

        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
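            # classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)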
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # 1. Sample gaussian noise
        batch_size = 2  # the second sample in the batch is the unconditional (empty-prompt) branch
        image = self.text_noise_scheduler.sample_noise(
            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
        # an empty prompt is tokenized alongside the real one to provide the unconditional
        # branch for classifier-free guidance
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
        num_timesteps = len(self.text_noise_scheduler)
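        # ancestral sampling: at each step draw x_{t-1} ~ N(mean, variance), keeping the
        # conditional and unconditional branches together in one batch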
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
            )
            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        # 4. Run the upscaling step
        batch_size = 1
        image = image[:1]
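        # quantize the 64x64 sample to 8-bit pixel values and back so it matches a real low-res image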
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
        eta = 0.0

        # Tune this parameter to control the sharpness of 256x256 images.
        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
        upsample_temp = 0.997

        image = (
            self.upscale_noise_scheduler.sample_noise(
                (batch_size, 3, 256, 256), device=torch_device, generator=generator
            )
            * upsample_temp
        )

        num_timesteps = len(self.upscale_noise_scheduler)
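        # the loop below applies the DDPM posterior update with a clipped x_0 estimate:
        #   x_0_hat = x_t / sqrt(alphabar_t) - sqrt(1 / alphabar_t - 1) * eps  (clamped to [-1, 1])
        #   mu_t    = clipped_coeff * x_0_hat + image_coeff * x_t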
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * self.upscale_noise_scheduler.get_beta(t)
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
            model_output = self.upscale_unet(image, time_input, low_res)
            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = self.upscale_noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        image = image[0].permute(1, 2, 0)

        return image
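

# Illustrative usage sketch (not part of the original pipeline code): the checkpoint identifier
# below is hypothetical, and loading assumes a GLIDE pipeline saved with `save_pretrained` so that
# `DiffusionPipeline.from_pretrained` can restore the six sub-modules registered in `__init__`.
if __name__ == "__main__":
    pipeline = GLIDE.from_pretrained("fusing/glide-base")  # hypothetical checkpoint id
    generator = torch.manual_seed(0)
    sample = pipeline("a painting of a corgi on the beach", generator=generator)
    # `sample` is a (256, 256, 3) tensor with values roughly in [-1, 1];
    # rescale to [0, 255] before saving it as an image.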