"vscode:/vscode.git/clone" did not exist on "137e75daa1d337b35a7ddc268f9d9e22de063530"
unet_new.py 4.64 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn

from .attention import AttentionBlock, LinearAttention, SpatialTransformer
from .resnet import ResnetBlock2D


class UNetMidBlock2D(nn.Module):
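    """UNet mid block: one initial resnet followed by `num_blocks` pairs of
    (attention, resnet) layers. The channel count stays at `in_channels`
    throughout."""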
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_blocks: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_layer_type: str = "self",
        attn_num_heads=1,
        attn_num_head_channels=None,
        attn_encoder_channels=None,
        attn_dim_head=None,
        attn_depth=None,
        output_scale_factor=1.0,
        overwrite_qkv=False,
        overwrite_unet=False,
    ):
        super().__init__()

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

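        # each iteration pairs one attention layer with one resnet; the
        # attention flavor is selected by `attention_layer_type`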
        for _ in range(num_blocks):
            if attention_layer_type == "self":
                attentions.append(
                    AttentionBlock(
                        in_channels,
                        num_heads=attn_num_heads,
                        num_head_channels=attn_num_head_channels,
                        encoder_channels=attn_encoder_channels,
                        overwrite_qkv=overwrite_qkv,
                        rescale_output_factor=output_scale_factor,
                    )
                )
            elif attention_layer_type == "spatial":
                attentions.append(
                    SpatialTransformer(
                        in_channels,
                        attn_num_heads,
                        attn_num_head_channels,
                        depth=attn_depth,
                        context_dim=attn_encoder_channels,
                    )
                )
            elif attention_layer_type == "linear":
                attentions.append(LinearAttention(in_channels))

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
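        # the first resnet runs on its own, before any attention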
        hidden_states = self.resnets[0](hidden_states, temb, mask=mask)

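        # the remaining resnets are interleaved one-to-one with the attention layers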
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_states)
            hidden_states = resnet(hidden_states, temb, mask=mask)

        return hidden_states

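# A minimal usage sketch (an illustration, not part of the original module). It
# assumes the sibling ResnetBlock2D and AttentionBlock implementations accept
# the arguments forwarded above, and that inputs use the
# (batch, channels, height, width) layout.
if __name__ == "__main__":
    import torch

    # in_channels must be divisible by resnet_groups (32 by default)
    block = UNetMidBlock2D(in_channels=64, temb_channels=128)
    sample = torch.randn(2, 64, 32, 32)
    temb = torch.randn(2, 128)

    # the mid block keeps the channel count and spatial shape unchanged
    output = block(sample, temb)
    assert output.shape == sample.shape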

# class UNetResAttnDownBlock(nn.Module):
#    def __init__(
#        self,
#        in_channels: int,
#        out_channels: int,
#        temb_channels: int,
#        dropout: float = 0.0,
#        resnet_eps: float = 1e-6,
#        resnet_time_scale_shift: str = "default",
#        resnet_act_fn: str = "swish",
#        resnet_groups: int = 32,
#        resnet_pre_norm: bool = True,
#        attention_layer_type: str = "self",
#        attn_num_heads=1,
#        attn_num_head_channels=None,
#        attn_encoder_channels=None,
#        attn_dim_head=None,
#        attn_depth=None,
#        output_scale_factor=1.0,
#        overwrite_qkv=False,
#        overwrite_unet=False,
#    ):
#
#        self.resnets =