# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn

from .attention import AttentionBlock, SpatialTransformer
from .resnet import ResnetBlock2D


class UNetMidBlock2D(nn.Module):
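    """
    The middle block of a 2D UNet: a first ResNet block, an attention layer,
    and a second ResNet block, all at the same channel width. The attention
    layer is either a self-attention ``AttentionBlock`` or a
    ``SpatialTransformer`` that can cross-attend to encoder states, selected
    via ``attention_layer_type``.
    """
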
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attention_layer_type: str = "self",
        attn_num_heads=1,
        attn_num_head_channels=None,
        attn_encoder_channels=None,
        attn_dim_head=None,
        attn_depth=None,
        output_scale_factor=1.0,
        overwrite_qkv=False,
        overwrite_unet=False,
    ):
        super().__init__()

        self.resnet_1 = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
        )

        # "self" selects plain self-attention; "spatial" selects a
        # SpatialTransformer that can also cross-attend to encoder states.
        if attention_layer_type == "self":
            self.attn = AttentionBlock(
                in_channels,
                num_heads=attn_num_heads,
                num_head_channels=attn_num_head_channels,
                encoder_channels=attn_encoder_channels,
                overwrite_qkv=overwrite_qkv,
                rescale_output_factor=output_scale_factor,
            )
        elif attention_layer_type == "spatial":
            self.attn = SpatialTransformer(
                in_channels,
                attn_num_heads,
                attn_num_head_channels,
                depth=attn_depth,
                context_dim=attn_encoder_channels,
            )
        else:
            raise ValueError(f"Unknown attention_layer_type: {attention_layer_type}")

        self.resnet_2 = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
        )

        # TODO(Patrick) - delete all of the following code
        self.is_overwritten = False
        self.overwrite_unet = overwrite_unet
        if self.overwrite_unet:
            block_in = in_channels
            self.temb_ch = temb_channels
            self.block_1 = ResnetBlock2D(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                eps=resnet_eps,
            )
            self.attn_1 = AttentionBlock(
                block_in,
                num_heads=attn_num_heads,
                num_head_channels=attn_num_head_channels,
                encoder_channels=attn_encoder_channels,
                overwrite_qkv=True,
            )
            self.block_2 = ResnetBlock2D(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                eps=resnet_eps,
            )

    def forward(self, hidden_states, temb=None, encoder_states=None):
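        # Legacy path (see TODO above): on the first call, swap in the
        # overwrite submodules so that weights loaded into `block_1`,
        # `attn_1`, and `block_2` are the ones actually used.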
        if not self.is_overwritten and self.overwrite_unet:
            self.resnet_1 = self.block_1
            self.attn = self.attn_1
            self.resnet_2 = self.block_2
            self.is_overwritten = True

        hidden_states = self.resnet_1(hidden_states, temb)

        # Cross-attend to the encoder states when provided; otherwise run
        # plain self-attention.
        if encoder_states is None:
            hidden_states = self.attn(hidden_states)
        else:
            hidden_states = self.attn(hidden_states, encoder_states)

        hidden_states = self.resnet_2(hidden_states, temb)
        return hidden_states
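

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical sizes; not part of the block's
    # API. Because this module uses relative imports, run it from within its
    # package, e.g. `python -m <package>.unet_new`.
    import torch

    mid_block = UNetMidBlock2D(in_channels=512, temb_channels=1280, dropout=0.0)
    sample = torch.randn(1, 512, 8, 8)  # (batch, channels, height, width)
    temb = torch.randn(1, 1280)  # timestep embedding
    out = mid_block(sample, temb)
    assert out.shape == sample.shape  # the mid block preserves the input shape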