from typing import Dict, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding for unconditional image generation.

    :param image_size: size of the input/output sample.
    :param in_channels: number of channels in the input sample.
    :param out_channels: number of channels in the output sample.
    :param num_res_blocks: number of ResNet layers per down/up block.
    :param dropout: the dropout probability.
    :param block_channels: tuple with the channel count of each UNet block, ordered from highest to lowest resolution.
    :param down_blocks: tuple of down-block class names making up the encoder half of the UNet.
    :param downsample_padding: padding used by the downsampling convolutions.
    :param up_blocks: tuple of up-block class names making up the decoder half of the UNet.
    :param resnet_act_fn: activation function used inside the ResNet blocks.
    :param resnet_eps: epsilon used by the normalization layers inside the ResNet blocks.
    :param conv_resample: if True, use learned convolutions for upsampling and downsampling.
    :param num_head_channels: number of channels per attention head.
    :param flip_sin_to_cos: whether to flip the sin/cos ordering in the sinusoidal timestep embedding.
    :param downscale_freq_shift: frequency shift applied to the sinusoidal timestep embedding.
    :param time_embedding_type: type of timestep embedding, either "positional" or "fourier".
    :param mid_block_scale_factor: scale factor applied to the output of the mid block.
    :param center_input_sample: if True, rescale the input sample from [0, 1] to [-1, 1] before processing.
    :param resnet_num_groups: number of groups used by the GroupNorm layers.
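
    Example (a minimal usage sketch; the concrete values for ``image_size``, ``in_channels``, ``out_channels`` and
    ``num_res_blocks`` below are illustrative, not required defaults):

        >>> import torch
        >>> model = UNetUnconditionalModel(image_size=32, in_channels=3, out_channels=3, num_res_blocks=2)
        >>> sample = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)
        >>> output = model(sample, timestep=10)["sample"]
        >>> output.shape
        torch.Size([1, 3, 32, 32])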
    """

    def __init__(
        self,
        image_size=None,
        in_channels=None,
        out_channels=None,
        num_res_blocks=None,
        dropout=0,
        block_channels=(224, 448, 672, 896),
        down_blocks=(
            "UNetResDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
        ),
        downsample_padding=1,
        up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        resnet_act_fn="silu",
        resnet_eps=1e-5,
        conv_resample=True,
        num_head_channels=32,
        flip_sin_to_cos=True,
        downscale_freq_shift=0,
        time_embedding_type="positional",
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=32,
    ):
        super().__init__()
        # register all __init__ params to be accessible via `self.config.<...>`
        # this should probably be automated down the road, as it is pure boilerplate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            time_embedding_type=time_embedding_type,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )

        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

        # time
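        # timestep embedding: Gaussian Fourier features for continuous timesteps, or the standard
        # sinusoidal positional embedding; either one is projected to `time_embed_dim` by
        # `TimestepEmbedding` below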
        if time_embedding_type == "fourier":
            self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
            timestep_input_dim = 2 * block_channels[0]
        elif time_embedding_type == "positional":
            self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
            timestep_input_dim = block_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.downsample_blocks = nn.ModuleList([])
        self.mid = None
        self.upsample_blocks = nn.ModuleList([])

        # down
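        # each down block maps the previous block's channel count to `block_channels[i]`; all blocks
        # except the last one also downsample the feature map (`add_downsample=not is_final_block`)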
        output_channel = block_channels[0]
        for i, down_block_type in enumerate(down_blocks):
            input_channel = output_channel
            output_channel = block_channels[i]
            is_final_block = i == len(block_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
                downsample_padding=downsample_padding,
            )
            self.downsample_blocks.append(down_block)

        # mid
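        # bottleneck block between the encoder and decoder halves, applied at the lowest resolution;
        # its output is scaled by `mid_block_scale_factor`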
        self.mid = UNetMidBlock2D(
            in_channels=block_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
            resnet_groups=resnet_num_groups,
        )

        # up
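        # the decoder mirrors the encoder: channel counts are consumed in reverse order, and each block
        # uses `num_res_blocks + 1` layers so it can absorb one saved skip connection per layer; only the
        # final (full-resolution) block omits the upsampler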
        reversed_block_channels = list(reversed(block_channels))
        output_channel = reversed_block_channels[0]
        for i, up_block_type in enumerate(up_blocks):
            prev_output_channel = output_channel
            output_channel = reversed_block_channels[i]
            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]

            is_final_block = i == len(block_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=num_res_blocks + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
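        # final GroupNorm -> SiLU -> 3x3 convolution mapping back to `out_channels`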
        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

    def forward(
        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
    ) -> Dict[str, torch.FloatTensor]:
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
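        # accept python numbers, 0-d tensors, or 1-d tensors for `timestep`; everything is converted to
        # a 1-d tensor on the sample's device before being embedded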
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        t_emb = self.time_steps(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
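        # each down block returns its intermediate hidden states, which are collected in
        # `down_block_res_samples` and consumed again by the matching up blocks; blocks that define a
        # `skip_conv` additionally thread a running `skip_sample` through the network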
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid(sample, emb)

        # 5. up
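        # consume the saved residuals in reverse order, `len(upsample_block.resnets)` tensors per block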
        skip_sample = None
        for upsample_block in self.upsample_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
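            # continuous-time ("fourier") models rescale the output by 1 / timestep, broadcasting the
            # timestep over all non-batch dimensions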
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        output = {"sample": sample}

        return output