# unet_unconditional.py
from typing import Dict, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
    model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
    num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
    rates at which
        attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
        downsampling, attention will be used.
    :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
    conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
    model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
    heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
    for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size=None,
        in_channels=None,
        out_channels=None,
        num_res_blocks=None,
        dropout=0,
        block_channels=(224, 448, 672, 896),
        down_blocks=(
            "UNetResDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
            "UNetResAttnDownBlock2D",
        ),
        downsample_padding=1,
        up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        resnet_act_fn="silu",
        resnet_eps=1e-5,
        conv_resample=True,
        num_head_channels=32,
        flip_sin_to_cos=True,
        downscale_freq_shift=0,
        time_embedding_type="positional",
        mid_block_scale_factor=1,
        center_input_sample=False,
        resnet_num_groups=32,
        **kwargs,
    ):
        super().__init__()
        # remove automatically added kwargs
        for arg in self._automatically_saved_args:
            kwargs.pop(arg, None)

        if len(kwargs) > 0:
            raise ValueError(
                f"The following keyword arguments do not exist for {self.__class__}: {', '.join(kwargs.keys())}"
            )

        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boilerplate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            time_embedding_type=time_embedding_type,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )

        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

        # time
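        # Two-stage conditioning: `self.time_steps` maps scalar timesteps to a sinusoidal
        # (or Gaussian Fourier) feature vector of size `timestep_input_dim`, and
        # `self.time_embedding` projects that with a small MLP to `time_embed_dim`
        # (4 * block_channels[0]), which is added inside every ResNet block.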
        if time_embedding_type == "fourier":
            self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
            timestep_input_dim = 2 * block_channels[0]
        elif time_embedding_type == "positional":
            self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
            timestep_input_dim = block_channels[0]
        else:
            raise ValueError(f"`time_embedding_type` must be 'fourier' or 'positional', got {time_embedding_type}")

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.downsample_blocks = nn.ModuleList([])
        self.mid = None
        self.upsample_blocks = nn.ModuleList([])

        # down
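        # Channel bookkeeping: block i maps the previous width to block_channels[i];
        # every block except the last ends with a spatial downsample.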
        output_channel = block_channels[0]
        for i, down_block_type in enumerate(down_blocks):
            input_channel = output_channel
            output_channel = block_channels[i]
            is_final_block = i == len(block_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
                downsample_padding=downsample_padding,
            )
            self.downsample_blocks.append(down_block)

        # mid
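        # Bottleneck at the deepest width (block_channels[-1]): ResNet layers with
        # interleaved self-attention; the spatial resolution is unchanged.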
        self.mid = UNetMidBlock2D(
            in_channels=block_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
            resnet_groups=resnet_num_groups,
        )

        # up
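        # Mirror of the down path: channels shrink back along reversed(block_channels),
        # and each block gets `num_res_blocks + 1` layers so it can also consume the
        # skip activations saved during the down pass.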
        reversed_block_channels = list(reversed(block_channels))
        output_channel = reversed_block_channels[0]
        for i, up_block_type in enumerate(up_blocks):
            prev_output_channel = output_channel
            output_channel = reversed_block_channels[i]
            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]

            is_final_block = i == len(block_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=num_res_blocks + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

    def forward(
        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
    ) -> Dict[str, torch.FloatTensor]:
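        """
        Apply the model to a noisy `sample` at the given diffusion `timestep`.

        :param sample: a (batch, in_channels, height, width) noisy input tensor.
        :param timestep: a scalar or a 1-D batch of diffusion timesteps.
        :return: a dict with key "sample" containing the model prediction.
        """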

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
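        # `timesteps` is now a 1-D tensor; the embedding layers broadcast it over the batch.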

        t_emb = self.time_steps(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
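        # `skip_sample` threads the raw input through down/up blocks that define a
        # `skip_conv` (score-based, NCSN-style variants); other blocks ignore it.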
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
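        # residual activations are pushed onto `down_block_res_samples` here and popped
        # in reverse (LIFO) order by the up blocks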
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.upsample_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
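            # continuous-time (Fourier) variants rescale the output by the noise scale;
            # reshaping to (batch, 1, 1, ...) lets the division broadcast over channels
            # and spatial dimensions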
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        output = {"sample": sample}

        return output
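

# A minimal usage sketch (illustrative only, not part of the module). The argument
# values below are assumptions chosen for demonstration; everything else relies on
# the defaults defined in `__init__`:
#
#     model = UNetUnconditionalModel(
#         image_size=64,
#         in_channels=3,
#         out_channels=3,
#         num_res_blocks=2,
#     )
#     noisy_images = torch.randn(4, 3, 64, 64)
#     timesteps = torch.randint(0, 1000, (4,))
#     prediction = model(noisy_images, timesteps)["sample"]  # same shape as the input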