unet_unconditional.py 8.54 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
from typing import Dict, Union

3
4
5
import torch
import torch.nn as nn

6
from ..configuration_utils import ConfigMixin, register_to_config
7
from ..modeling_utils import ModelMixin
8
9
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
10
11


12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    """
    Unconditional 2D UNet that takes a sample and a timestep and returns a sample-shaped output.

    The network is assembled as: input convolution -> down blocks -> mid block -> up blocks ->
    GroupNorm / SiLU / output convolution. The timestep is embedded once and handed to every block.

    :param image_size: stored on ``self.image_size``; not otherwise used while building the model.
    :param in_channels: number of channels of the input sample (input channels of ``conv_in``). Must be provided.
    :param out_channels: number of channels of the output sample (output channels of ``conv_out``). Must be
        provided.
    :param num_res_blocks: number of layers per down block; every up block gets ``num_res_blocks + 1`` layers. Must
        be provided.
    :param dropout: dropout value forwarded to the mid block.
    :param block_channels: output channel count of each resolution level. The first entry is also the width of
        ``conv_in`` and of the final norm / convolution, and the time embedding has ``4 * block_channels[0]``
        channels.
    :param down_blocks: block type names, one per level, resolved through ``get_down_block``.
    :param downsample_padding: forwarded to every down block.
    :param up_blocks: block type names, one per level, resolved through ``get_up_block``.
    :param resnet_act_fn: activation name forwarded to the down, mid and up blocks.
    :param resnet_eps: epsilon forwarded to the down, mid and up blocks and used by the final ``GroupNorm``.
    :param conv_resample: accepted for configuration purposes; not used while building the model.
    :param num_head_channels: forwarded to the blocks as ``attn_num_head_channels``.
    :param flip_sin_to_cos: passed to ``Timesteps`` when ``time_embedding_type == "positional"``.
    :param downscale_freq_shift: passed to ``Timesteps`` when ``time_embedding_type == "positional"``.
    :param time_embedding_type: ``"positional"`` (uses ``Timesteps``) or ``"fourier"`` (uses
        ``GaussianFourierProjection``; the model output is additionally divided by the timestep). Any other value
        raises ``ValueError``.
    :param mid_block_scale_factor: forwarded to the mid block as ``output_scale_factor``.
    :param center_input_sample: if true, ``forward`` maps the input through ``2 * sample - 1.0`` first.
    :param resnet_num_groups: group count for the mid block and the final ``GroupNorm``. If ``None``, the final
        norm falls back to ``min(block_channels[0] // 4, 32)`` groups.
    """

    @register_to_config
    def __init__(
        self,
        image_size=None,
        in_channels=None,
        out_channels=None,
        num_res_blocks=None,
        dropout=0,
        block_channels=(224, 448, 672, 896),
        down_blocks=(
            "UNetResDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
        ),
        downsample_padding=1,
        up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        resnet_act_fn="silu",
        resnet_eps=1e-5,
        conv_resample=True,
        num_head_channels=32,
        flip_sin_to_cos=True,
        downscale_freq_shift=0,
        time_embedding_type="positional",
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=32,
    ):
        super().__init__()
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
            # The projection output is twice as wide as `embedding_size`.
            timestep_input_dim = 2 * block_channels[0]
        elif time_embedding_type == "positional":
            self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
            timestep_input_dim = block_channels[0]
        else:
            # Fail with a clear message instead of an UnboundLocalError on `timestep_input_dim` below.
            raise ValueError(
                f"Unknown time_embedding_type {time_embedding_type!r}; expected 'fourier' or 'positional'."
            )

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # Containers are created in this order on purpose: it fixes the module registration order
        # (downsample_blocks, upsample_blocks, mid) and therefore the parameter ordering.
        self.downsample_blocks = nn.ModuleList([])
        self.mid = None
        self.upsample_blocks = nn.ModuleList([])

        # down
        output_channel = block_channels[0]
        for i, down_block_type in enumerate(down_blocks):
            input_channel = output_channel
            output_channel = block_channels[i]
            # The last level keeps its resolution: no downsampling after it.
            is_final_block = i == len(block_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
                downsample_padding=downsample_padding,
            )
            self.downsample_blocks.append(down_block)

        # mid
        self.mid = UNetMidBlock2D(
            in_channels=block_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
            resnet_groups=resnet_num_groups,
        )

        # up
        reversed_block_channels = list(reversed(block_channels))
        output_channel = reversed_block_channels[0]
        for i, up_block_type in enumerate(up_blocks):
            prev_output_channel = output_channel
            output_channel = reversed_block_channels[i]
            # Channel count of the next (shallower) level, clamped at the last level.
            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]

            # The last up block is already at the output resolution: no upsampling after it.
            is_final_block = i == len(block_channels) - 1

            up_block = get_up_block(
                up_block_type,
                # One more layer than the down blocks, matching the extra residual pushed in `forward`.
                num_layers=num_res_blocks + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)

        # out
        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

    def forward(
        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Run the UNet on ``sample`` at the given ``timestep``.

        :param sample: input tensor; it is fed to a ``Conv2d``, so a batched image-like tensor is expected.
        :param timestep: a tensor of timesteps, or a single Python ``int`` / ``float``. A single value (Python
            number or 0-dim tensor) is turned into a one-element tensor on ``sample``'s device and shared by the
            whole batch.
        :return: a dict with the single key ``"sample"`` holding the model output.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # Keep Python floats as floating point: casting them to `long` would silently truncate
            # fractional timesteps (e.g. 0.5 -> 0). Integers keep the original `long` dtype.
            timestep_dtype = torch.float32 if isinstance(timesteps, float) else torch.long
            timesteps = torch.tensor([timesteps], dtype=timestep_dtype, device=sample.device)
        elif timesteps.dim() == 0:
            # Promote a 0-dim tensor to a one-element 1-D tensor.
            timesteps = timesteps[None].to(sample.device)

        t_emb = self.time_steps(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        # Residuals are collected in order; the up path consumes them from the end.
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            if hasattr(downsample_block, "skip_conv"):
                # Blocks exposing `skip_conv` also carry a running skip sample.
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.upsample_blocks:
            # Each up block pops as many residuals as it has resnets.
            num_resnets = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_resnets:]
            down_block_res_samples = down_block_res_samples[:-num_resnets]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            # Reshape to (N, 1, ..., 1) so the division broadcasts over the non-batch dims. Using -1
            # instead of the batch size also lets a single timestep be shared by a batch larger than
            # one; with one timestep per sample the result is identical to reshaping to (batch, 1, ...).
            timesteps = timesteps.reshape((-1, *([1] * (sample.dim() - 1))))
            sample = sample / timesteps

        output = {"sample": sample}

        return output