# unet_unconditional.py
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import (
    UNetMidBlock2D,
    UNetResAttnDownBlock2D,
    UNetResAttnUpBlock2D,
    UNetResDownBlock2D,
    UNetResUpBlock2D,
)


class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    """
    Unconditional 2D UNet with a sinusoidal timestep embedding, built from the blocks in `unet_new`.

    The network is `conv_in` -> down blocks -> mid block -> up blocks (with skip connections) -> `out`. During
    construction a legacy LDM-style module tree is also built by `init_for_ldm` and its weights are copied into the
    new blocks; that conversion path is marked for deletion once weight conversion is finished.

    Parameters:
        image_size: stored in the config and on `self.image_size`; not used to build any layer in this class.
        in_channels: number of channels of the input `sample` (input of `conv_in`).
        out_channels: number of channels produced by the final convolution.
        num_res_blocks: number of resnet layers per down block; every up block uses `num_res_blocks + 1`.
        attention_resolutions: collection of downsampling factors at which attention blocks are used. The first
            resolution level has factor 1 and the factor doubles for every following level, e.g. if this contains 4,
            the third level uses attention.
        dropout: dropout probability forwarded to the mid block (and to the legacy LDM resnets).
        resnet_input_channels: input channel count of every resolution level. The first entry is also the output
            width of `conv_in` and the size of the sinusoidal timestep embedding; the projected timestep embedding
            has 4x that size.
        resnet_output_channels: output channel count of every resolution level.
        conv_resample: passed as `use_conv` to the legacy LDM `Downsample2D` / `Upsample2D` layers only.
        num_head_channels: number of channels per attention head.
    """

    def init_for_ldm(
        self,
        in_channels,
        model_channels,
        channel_mult,
        num_res_blocks,
        dropout,
        time_embed_dim,
        attention_resolutions,
        num_head_channels,
        num_heads,
        legacy,
        use_spatial_transformer,
        transformer_depth,
        context_dim,
        conv_resample,
        out_channels,
    ):
        """
        Build the legacy LDM module tree and transfer its weights into the new blocks.

        Registers `self.input_blocks`, `self.middle_block` and `self.output_blocks` (LDM layout), then copies their
        parameters into `self.downsample_blocks`, `self.mid`, `self.upsample_blocks` and `self.conv_in` (via each
        block's `set_weight` and direct `.data` assignment). It must therefore run after those new modules exist.

        `use_spatial_transformer`, `transformer_depth`, `context_dim` and `out_channels` are accepted for signature
        compatibility but are not used.
        """
        # TODO(PVP) - delete after weight conversion

        class TimestepEmbedSequential(nn.Sequential):
            """
            Plain `nn.Sequential` container mirroring the LDM module of the same name.

            It only groups the legacy layers whose weights are copied below; it is not used in `forward`.
            """

            pass

        # TODO(PVP) - delete after weight conversion
        def conv_nd(dims, *args, **kwargs):
            """
            Create a 1D, 2D, or 3D convolution module.
            """
            if dims == 1:
                return nn.Conv1d(*args, **kwargs)
            elif dims == 2:
                return nn.Conv2d(*args, **kwargs)
            elif dims == 3:
                return nn.Conv3d(*args, **kwargs)
            raise ValueError(f"unsupported dimensions: {dims}")

        def resolve_heads(channels, heads):
            """
            Return the `(num_heads, head_dim)` pair for an attention layer over `channels` channels.

            With `num_head_channels == -1` the given head count is kept and the head size derived from it; otherwise
            the head count is derived from the fixed head size. In legacy mode the head size is always
            `num_head_channels`.
            """
            if num_head_channels == -1:
                head_dim = channels // heads
            else:
                heads = channels // num_head_channels
                head_dim = num_head_channels
            if legacy:
                head_dim = num_head_channels
            return heads, head_dim

        dims = 2
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResnetBlock2D(
                        in_channels=ch,
                        out_channels=mult * model_channels,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    num_heads, dim_head = resolve_heads(ch, num_heads)
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                        ),
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # every level except the last ends with a downsampler
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        num_heads, dim_head = resolve_heads(ch, num_heads)
        if dim_head < 0:
            dim_head = None

        # TODO(Patrick) - delete after weight conversion
        # init to be able to overwrite `self.mid`
        self.middle_block = TimestepEmbedSequential(
            ResnetBlock2D(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=dim_head,
            ),
            ResnetBlock2D(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # skip connection from the matching input block
                ich = input_block_chans.pop()
                layers = [
                    ResnetBlock2D(
                        in_channels=ch + ich,
                        out_channels=model_channels * mult,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    ),
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    num_heads, dim_head = resolve_heads(ch, num_heads)
                    # NOTE(review): unlike the input/middle blocks, output attention passes num_heads=-1 and relies
                    # on num_head_channels alone — kept as-is; confirm against AttentionBlock.
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=-1,
                            num_head_channels=dim_head,
                        ),
                    )
                # every level except the first ends with an upsampler
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # ================ SET WEIGHTS OF ALL WEIGHTS ==================
        # Each level contributes `num_res_blocks` resnet entries to `input_blocks[1:]`, followed by one downsampler on
        # every level except the last, so the downsampler sits at position `num_res_blocks` within its level.
        layers_per_level = num_res_blocks + 1
        for i, input_layer in enumerate(self.input_blocks[1:]):
            block_id, layer_in_block_id = divmod(i, layers_per_level)
            down_block = self.downsample_blocks[block_id]

            if layer_in_block_id == num_res_blocks:
                down_block.downsamplers[0].op.weight.data = input_layer[0].op.weight.data
                down_block.downsamplers[0].op.bias.data = input_layer[0].op.bias.data
            else:
                down_block.resnets[layer_in_block_id].set_weight(input_layer[0])
                if len(input_layer) > 1:
                    down_block.attentions[layer_in_block_id].set_weight(input_layer[1])

        self.mid.resnets[0].set_weight(self.middle_block[0])
        self.mid.resnets[1].set_weight(self.middle_block[2])
        self.mid.attentions[0].set_weight(self.middle_block[1])

        # Each output entry is [resnet, (attention), (upsampler)].
        for i, output_layer in enumerate(self.output_blocks):
            block_id, layer_in_block_id = divmod(i, layers_per_level)
            up_block = self.upsample_blocks[block_id]

            up_block.resnets[layer_in_block_id].set_weight(output_layer[0])
            if len(output_layer) > 1 and not isinstance(output_layer[1], Upsample2D):
                up_block.attentions[layer_in_block_id].set_weight(output_layer[1])
            if isinstance(output_layer[-1], Upsample2D):
                up_block.upsamplers[0].conv.weight.data = output_layer[-1].conv.weight.data
                up_block.upsamplers[0].conv.bias.data = output_layer[-1].conv.bias.data

        self.conv_in.weight.data = self.input_blocks[0][0].weight.data
        self.conv_in.bias.data = self.input_blocks[0][0].bias.data

    def __init__(
        self,
        image_size,
        in_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        resnet_input_channels=(224, 224, 448, 672),
        resnet_output_channels=(224, 448, 672, 896),
        conv_resample=True,
        num_head_channels=32,
    ):
        super().__init__()

        # register all __init__ params in the config
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            resnet_input_channels=resnet_input_channels,
            resnet_output_channels=resnet_output_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
        )

        # To delete - replace with config values
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout

        time_embed_dim = resnet_input_channels[0] * 4

        # ======================== Input ===================
        self.conv_in = nn.Conv2d(in_channels, resnet_input_channels[0], kernel_size=3, padding=(1, 1))

        # ======================== Time ====================
        self.time_embed = nn.Sequential(
            nn.Linear(resnet_input_channels[0], time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # ======================== Down ====================
        input_channels = list(resnet_input_channels)
        output_channels = list(resnet_output_channels)

        # NOTE(review): `dropout` is only forwarded to the mid block and `conv_resample` only to the legacy LDM
        # layers; the new down/up blocks never receive them — confirm whether that is intended.

        # downsampling factor of the current level (1, 2, 4, ...), matched against `attention_resolutions`
        ds_new = 1
        self.downsample_blocks = nn.ModuleList([])
        for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
            is_final_block = i == len(input_channels) - 1

            if ds_new in attention_resolutions:
                down_block = UNetResAttnDownBlock2D(
                    num_layers=num_res_blocks,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    add_downsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                    attn_num_head_channels=num_head_channels,
                )
            else:
                down_block = UNetResDownBlock2D(
                    num_layers=num_res_blocks,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    add_downsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                )

            self.downsample_blocks.append(down_block)

            ds_new *= 2

        # undo the doubling after the last level so the up path starts at the deepest factor (kept as an int)
        ds_new //= 2

        # ======================== Mid ====================
        self.mid = UNetMidBlock2D(
            in_channels=output_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=1e-5,
            resnet_act_fn="silu",
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
        )

        # ======================== Up =====================
        self.upsample_blocks = nn.ModuleList([])
        for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
            is_final_block = i == len(input_channels) - 1

            if ds_new in attention_resolutions:
                up_block = UNetResAttnUpBlock2D(
                    num_layers=num_res_blocks + 1,
                    in_channels=output_channel,
                    next_channels=input_channel,
                    temb_channels=time_embed_dim,
                    add_upsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                    attn_num_head_channels=num_head_channels,
                )
            else:
                up_block = UNetResUpBlock2D(
                    num_layers=num_res_blocks + 1,
                    in_channels=output_channel,
                    next_channels=input_channel,
                    temb_channels=time_embed_dim,
                    add_upsample=not is_final_block,
                    resnet_eps=1e-5,
                    resnet_act_fn="silu",
                )

            self.upsample_blocks.append(up_block)

            ds_new //= 2

        # ======================== Out ====================
        # GroupNorm and the final conv see the same tensor, so both must use the same channel count.
        self.out = nn.Sequential(
            nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
            nn.SiLU(),
            nn.Conv2d(output_channels[0], out_channels, 3, padding=1),
        )

        # =========== TO DELETE AFTER CONVERSION ==========
        transformer_depth = 1
        context_dim = None
        legacy = True
        num_heads = -1
        model_channels = resnet_input_channels[0]
        channel_mult = tuple(x // model_channels for x in resnet_output_channels)
        self.init_for_ldm(
            in_channels,
            model_channels,
            channel_mult,
            num_res_blocks,
            dropout,
            time_embed_dim,
            attention_resolutions,
            num_head_channels,
            num_heads,
            legacy,
            False,  # use_spatial_transformer
            transformer_depth,
            context_dim,
            conv_resample,
            out_channels,
        )

    def forward(self, sample, timesteps=None):
        """
        Run the UNet on a batch of samples.

        Args:
            sample: input tensor fed to `conv_in` (an `nn.Conv2d`), presumably (batch, in_channels, height, width).
            timesteps: a python scalar, a 0-d tensor or a 1-d tensor of timesteps. Despite the `None` default a value
                is required: `None` cannot be converted to a long tensor.

        Returns:
            The output of `self.out`, with `out_channels` channels.
        """
        # 1. time step embeddings
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif timesteps.dim() == 0:
            # promote a scalar tensor to shape (1,), matching the python-scalar path above
            timesteps = timesteps[None].to(sample.device)

        t_emb = get_timestep_embedding(
            timesteps, self.config.resnet_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
        )
        emb = self.time_embed(t_emb)

        # 2. pre-process sample
        sample = self.conv_in(sample)

        # 3. down blocks
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            sample, res_samples = downsample_block(sample, emb)

            # append to tuple
            down_block_res_samples += res_samples

        # 4. mid block
        sample = self.mid(sample, emb)

        # 5. up blocks
        for upsample_block in self.upsample_blocks:
            # pop one skip connection per resnet of this up block from the end of the tuple
            num_skips = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_skips:]
            down_block_res_samples = down_block_res_samples[:-num_skips]

            sample = upsample_block(sample, res_samples, emb)

        # 6. post-process sample
        sample = self.out(sample)

        return sample