"scripts/vscode:/vscode.git/clone" did not exist on "4d9f82016e00cbe041d5c35925dd08ff51db922d"
modeling_ddpm.py 2.78 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from diffusers import DiffusionPipeline
Patrick von Platen's avatar
up  
Patrick von Platen committed
18
19
import tqdm
import torch
Patrick von Platen's avatar
Patrick von Platen committed
20
21
22


class DDPM(DiffusionPipeline):
    """DDPM sampling pipeline.

    Combines a noise-predicting ``unet`` with a ``noise_scheduler`` and, when
    called, runs the reverse diffusion process starting from Gaussian noise.
    """

    modeling_file = "modeling_ddpm.py"

    def __init__(self, unet, noise_scheduler):
        """Register the pipeline components.

        Args:
            unet: Noise-prediction network. It is called as ``unet(image, t)``
                and must expose ``in_channels`` and ``resolution``.
            noise_scheduler: Scheduler providing the noise schedule
                (``get_alpha``, ``get_beta``, ``get_alpha_prod``), its number
                of time steps via ``len()``, and the ``sample_noise`` /
                ``sample_variance`` helpers.
        """
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None):
        """Generate a batch of samples by running the reverse diffusion loop.

        Args:
            batch_size: Number of samples to generate.
            generator: Optional random generator, forwarded to the scheduler's
                ``sample_noise`` and ``sample_variance`` for reproducibility.
            torch_device: Device to run on. Defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.

        Returns:
            The tensor produced by the final (t = 0) reverse step. It starts
            with shape ``(batch_size, unet.in_channels, unet.resolution,
            unet.resolution)``, assuming the unet and scheduler preserve shape.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        scheduler = self.noise_scheduler

        # 1. Sample gaussian noise x_T.
        image = scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )

        num_steps = len(scheduler)
        for t in tqdm.tqdm(reversed(range(num_steps)), total=num_steps):
            # i) define coefficients for time step t. Each schedule value is
            # looked up once per step instead of being re-queried per term.
            alpha_prod_t = scheduler.get_alpha_prod(t)
            # NOTE(review): at the last step this queries t - 1 == -1;
            # presumably the scheduler returns 1 for negative t — confirm.
            alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)

            clip_image_coeff = 1 / torch.sqrt(alpha_prod_t)
            clip_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)
            image_coeff = (1 - alpha_prod_t_prev) * torch.sqrt(scheduler.get_alpha(t)) / (1 - alpha_prod_t)
            clip_coeff = torch.sqrt(alpha_prod_t_prev) * scheduler.get_beta(t) / (1 - alpha_prod_t)

            # ii) predict noise residual (inference only, so no autograd graph).
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) reconstruct the predicted clean image x_0 from x_t and the
            # predicted noise, clamp it to [-1, 1], then form the posterior
            # mean of x_{t-1} as a weighted sum of x_0 and x_t.
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_original_image = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_original_image = torch.clamp(pred_original_image, -1, 1)
            prev_image = clip_coeff * pred_original_image + image_coeff * image

            # iv) sample variance. The result is added directly to the mean, so
            # it is presumably an already-scaled noise sample (sigma_t * z)
            # rather than a variance — confirm against the scheduler.
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            image = prev_image + prev_variance

        return image