# upscaling.py - low-scale models for concatenating a downsampled image to the latent representation.
from functools import partial

# Third-party numerical libraries.
import numpy as np
import torch
import torch.nn as nn
# Helpers from the ldm package.
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default


class AbstractLowScaleModel(nn.Module):
    """Base module for concatenating a downsampled image to the latent representation.

    When ``noise_schedule_config`` is given, its entries are forwarded as
    keyword arguments to :meth:`register_schedule`, which stores the noise
    schedule as module buffers. Without a config no schedule is registered,
    and :meth:`q_sample` cannot be used.
    """

    def __init__(self, noise_schedule_config=None):
        """Initialise the module and optionally register a noise schedule.

        Args:
            noise_schedule_config: Optional mapping of keyword arguments for
                :meth:`register_schedule`. ``None`` skips schedule registration.
        """
        super().__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(
        self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
    ):
        """Build the beta schedule and register derived quantities as float32 buffers.

        Args:
            beta_schedule: Schedule name passed through to ``make_beta_schedule``.
            timesteps: Requested number of steps, passed to ``make_beta_schedule``.
            linear_start: Passed to ``make_beta_schedule``; also stored on ``self``.
            linear_end: Passed to ``make_beta_schedule``; also stored on ``self``.
            cosine_s: Passed to ``make_beta_schedule``.

        Raises:
            AssertionError: If the cumulative alpha product does not have one
                entry per timestep.
        """
        # NOTE(review): make_beta_schedule is assumed to return a 1-D NumPy-compatible
        # array; the tuple unpacking of ``betas.shape`` below relies on that.
        betas = make_beta_schedule(
            beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 as the value before step 0.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # Use the actual schedule length rather than the requested ``timesteps``.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
        self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
        self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        """Combine ``x_start`` with noise using the schedule coefficients selected by ``t``.

        Requires that :meth:`register_schedule` has been called.

        Args:
            x_start: Input tensor to be noised.
            t: Timestep indices handed to ``extract_into_tensor``.
            noise: Optional noise tensor. When omitted, ``default`` is given a
                factory producing ``torch.randn_like(x_start)``.

        Returns:
            ``sqrt_alphas_cumprod`` term times ``x_start`` plus
            ``sqrt_one_minus_alphas_cumprod`` term times ``noise``.
        """
        # NOTE(review): ``default`` presumably returns ``noise`` unless it is None,
        # in which case it calls the factory - confirm against ldm.util.
        noise = default(noise, lambda: torch.randn_like(x_start))
        # NOTE(review): extract_into_tensor presumably picks the coefficient at index
        # ``t`` for each sample and shapes it to broadcast over ``x_start`` - confirm.
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def forward(self, x):
        """Return ``x`` unchanged together with ``None`` as the noise level."""
        return x, None

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    """Low-scale model with no noise level conditioning.

    No noise schedule is registered, and every sample is reported at noise
    level zero.
    """

    def __init__(self):
        """Initialise the base class without a schedule and pin the maximum noise level to 0."""
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        """Return ``x`` unchanged with a constant zero noise level per batch entry.

        Args:
            x: Input tensor; its first dimension sets the length of the
                returned noise-level tensor.

        Returns:
            Tuple ``(x, levels)`` where ``levels`` is a long tensor of zeros of
            length ``x.shape[0]`` on ``x``'s device.
        """
        batch_size = x.shape[0]
        # fix to constant noise level
        constant_levels = torch.zeros(batch_size, device=x.device).long()
        return x, constant_levels


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Low-scale model that noises its input through ``q_sample`` before returning it.

    The noise level is either supplied by the caller or drawn uniformly at
    random per batch entry from ``[0, max_noise_level)``.
    """

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        """Register the noise schedule and store the noise-level upper bound.

        Args:
            noise_schedule_config: Passed to the base class as the noise
                schedule configuration.
            max_noise_level: Exclusive upper bound for randomly drawn noise levels.
            to_cuda: Accepted for interface compatibility; not used by this class.
        """
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        """Noise ``x`` at ``noise_level`` and return both the result and the level.

        Args:
            x: Input tensor; its first dimension sets how many noise levels are drawn.
            noise_level: Optional tensor of noise levels. When ``None``, levels
                are sampled with ``torch.randint`` on ``x``'s device.

        Returns:
            Tuple ``(z, noise_level)`` where ``z`` is ``self.q_sample(x, noise_level)``.

        Raises:
            TypeError: If ``noise_level`` is given but is not a ``torch.Tensor``.
        """
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        elif not isinstance(noise_level, torch.Tensor):
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise TypeError(f"noise_level must be a torch.Tensor, got {type(noise_level).__name__}")
        z = self.q_sample(x, noise_level)
        return z, noise_level