import math

import torch
import torch.nn.functional as F
from torch import nn


# unet_grad_tts.py
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
        self.dim_head = dim_head
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # einops equivalent: q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        q, k, v = (
            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
            .permute(1, 0, 2, 3, 4, 5)
            .reshape(3, b, self.heads, self.dim_head, -1)
        )
        k = k.softmax(dim=-1)
        # linear attention: build a small (dim_head x dim_head) context from the softmax-normalized keys and the
        # values, then read it out with the queries instead of forming a full (h*w x h*w) attention matrix
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        # einops equivalent: out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)

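# Usage sketch (not part of the original module; the sizes below are arbitrary assumptions):
# LinearAttention is shape-preserving on (batch, dim, height, width) feature maps, e.g.
#
#     attn = LinearAttention(dim=64, heads=4, dim_head=32)
#     x = torch.randn(2, 64, 16, 16)
#     out = attn(x)  # -> torch.Size([2, 64, 16, 16])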

# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
        overwrite_qkv=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels, swish=0.0)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.n_heads = self.num_heads

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

        self.overwrite_qkv = overwrite_qkv
        if overwrite_qkv:
            in_channels = channels
            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.is_overwritten = False

    def set_weights(self, module):
        if self.overwrite_qkv:
            # fuse the separate 1x1 q/k/v convolutions of `module` into this block's single 1D qkv projection
            qkv_weight = torch.cat(
                [module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0
            )[:, :, :, 0]
            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

            self.qkv.weight.data = qkv_weight
            self.qkv.bias.data = qkv_bias

            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
            proj_out.bias.data = module.proj_out.bias.data

            self.proj_out = proj_out

    def forward(self, x, encoder_out=None):
        if self.overwrite_qkv and not self.is_overwritten:
            self.set_weights(self)
            self.is_overwritten = True

        b, c, *spatial = x.shape
        hid_states = self.norm(x).view(b, c, -1)

        qkv = self.qkv(hid_states)
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # fold the heads into the batch dimension and split the fused projection into per-head q, k, v
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)

        if encoder_out is not None:
            # cross-attention: prepend keys/values projected from the encoder so every spatial position
            # can also attend to the encoder sequence
            encoder_kv = self.encoder_kv(encoder_out)
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)

        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        a = torch.einsum("bts,bcs->bct", weight, v)
        h = a.reshape(bs, -1, length)

        h = self.proj_out(h)

        return x + h.reshape(b, c, *spatial)

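# Usage sketch (not part of the original module; the sizes below are arbitrary assumptions):
# AttentionBlock applies a residual self-attention update over the flattened spatial positions
# and can optionally cross-attend to an encoder sequence via `encoder_out`.
#
#     block = AttentionBlock(channels=64, num_heads=4)
#     x = torch.randn(2, 64, 16, 16)
#     y = block(x)  # -> torch.Size([2, 64, 16, 16]), same shape as x
#
#     xattn = AttentionBlock(channels=64, num_heads=4, encoder_channels=512)
#     text_states = torch.randn(2, 512, 77)  # (batch, encoder_channels, sequence_length)
#     y = xattn(x, encoder_out=text_states)  # still the same shape as x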

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * torch.sigmoid(y * float(self.swish))
        return y


def normalization(channels, swish=0.0, eps=1e-5):
    """
    Make a standard normalization layer, with an optional swish activation.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)

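# Usage sketch (not part of the original module; the sizes below are arbitrary assumptions):
# `normalization` wraps GroupNorm32 with 32 groups, so `channels` must be divisible by 32.
# With swish=1.0 a SiLU is applied after the norm; with swish=0.0 it is a plain GroupNorm.
#
#     norm = normalization(128, swish=1.0)
#     h = norm(torch.randn(2, 128, 16, 16))  # same shape, normalized and SiLU-activated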

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

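# Usage sketch (not part of the original module): `zero_module` is used above to zero-initialize the
# output projection of AttentionBlock, so the residual `x + proj_out(...)` starts out as the identity
# and the attention contribution is learned from zero.
#
#     proj = zero_module(conv_nd(1, 64, 64, 1))
#     assert all((p == 0).all() for p in proj.parameters())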

# unet_score_estimation.py
# NOTE: kept commented out below; it depends on a `NIN` layer and on numpy (`np`), neither of which
# is defined or imported in this file.
# class AttnBlockpp(nn.Module):
#    """Channel-wise self-attention block. Modified from DDPM."""
#
#    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
#        super().__init__()
#        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
#        self.NIN_0 = NIN(channels, channels)
#        self.NIN_1 = NIN(channels, channels)
#        self.NIN_2 = NIN(channels, channels)
#        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
#        self.skip_rescale = skip_rescale
#
#    def forward(self, x):
#        B, C, H, W = x.shape
#        h = self.GroupNorm_0(x)
#        q = self.NIN_0(h)
#        k = self.NIN_1(h)
#        v = self.NIN_2(h)
#
#        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
#        w = torch.reshape(w, (B, H, W, H * W))
#        w = F.softmax(w, dim=-1)
#        w = torch.reshape(w, (B, H, W, H, W))
#        h = torch.einsum("bhwij,bcij->bchw", w, v)
#        h = self.NIN_3(h)
#        if not self.skip_rescale:
#            return x + h
#        else:
#            return (x + h) / np.sqrt(2.0)