attention.py 10.4 KB
Newer Older
1
import math
2
3
4
from inspect import isfunction
from typing import Any, Optional

5
6
7
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
Fazzie's avatar
Fazzie committed
8
from ldm.modules.diffusionmodules.util import checkpoint
9
from torch import einsum, nn
10
11

try:
    # xformers is an optional accelerator; importing the ops submodule also binds ``xformers``.
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Any error raised while importing xformers (missing package or a broken install) falls
    # back to vanilla attention. ``Exception`` instead of a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate.
    XFORMERS_IS_AVAILBLE = False

# Correctly spelled alias. The misspelled name above is kept because other modules may import it.
XFORMERS_IS_AVAILABLE = XFORMERS_IS_AVAILBLE


def exists(val):
    """Tell whether ``val`` is set, i.e. is anything other than ``None``.

    Falsy values such as ``0`` or ``""`` still count as set.
    """
    if val is None:
        return False
    return True


def uniq(arr):
    """Return the distinct elements of ``arr`` in first-seen order.

    The result is a ``dict_keys`` view (elements must be hashable), exactly as the
    ``{el: True ...}.keys()`` form produced.
    """
    return dict.fromkeys(arr).keys()


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    If ``d`` is a plain Python function (as judged by ``inspect.isfunction``), it is called
    and its result returned, so the fallback is only computed when actually needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d


def max_neg_value(t):
    """Return the most negative finite number representable in ``t``'s floating dtype.

    Useful as a finite stand-in for ``-inf`` when masking logits.
    """
    info = torch.finfo(t.dtype)
    return -info.max


def init_(tensor):
    """Fill ``tensor`` in place from U(-b, b) with b = 1 / sqrt(size of its last dimension).

    Returns the same tensor object, so the call can be chained.
    """
    fan = tensor.shape[-1]
    bound = 1 / math.sqrt(fan)
    return tensor.uniform_(-bound, bound)


# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection.

    A single ``Linear(dim_in, 2 * dim_out)`` produces a value half and a gate half along the
    last dimension; the output is ``value * gelu(gate)`` with last dimension ``dim_out``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Position-wise MLP: ``dim -> inner_dim`` projection + activation, dropout, ``inner_dim -> dim_out``.

    Args:
        dim: input feature size (last dimension of ``x``).
        dim_out: output feature size; defaults to ``dim``.
        mult: hidden-size multiplier, ``inner_dim = int(dim * mult)``.
        glu: use a gated :class:`GEGLU` input projection instead of ``Linear`` + ``GELU``.
        dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = GEGLU(dim, inner_dim) if glu else nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        # Submodule order (0: project_in, 1: dropout, 2: output Linear) determines the state_dict keys.
        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """Zero every parameter of ``module`` in place and hand the same module back.

    Each parameter is zeroed through a ``detach()``-ed view that shares its storage,
    so autograd does not record the in-place write.
    """
    for param in module.parameters():
        shared_view = param.detach()
        shared_view.zero_()
    return module


def Normalize(in_channels):
    """Build a 32-group ``GroupNorm`` over ``in_channels`` channels with eps=1e-6 and learnable affine.

    ``in_channels`` must be divisible by 32 (``GroupNorm`` requirement).
    """
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a ``(b, c, h, w)`` feature map.

    Queries, keys and values are 1x1 convolutions of the group-normalised input. The attended
    features go through a 1x1 output convolution and are added back to the input (residual).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: w_[b, i, j] = softmax_j(<q_i, k_j> / sqrt(c)) over the h*w positions
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` (queries) to ``context`` (keys/values).

    Args:
        query_dim: feature size of ``x``; also the output feature size.
        context_dim: feature size of ``context``; defaults to ``query_dim`` (self-attention).
        heads: number of attention heads.
        dim_head: per-head feature size; ``inner_dim = heads * dim_head``.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context``, or to ``x`` itself when ``context`` is None.

        Args:
            x: queries of shape ``(b, n, query_dim)``.
            context: keys/values of shape ``(b, m, context_dim)``; defaults to ``x``.
            mask: optional boolean tensor with leading batch dim ``b`` whose remaining dims
                flatten to ``m``; ``True`` marks key positions that may be attended to.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Fold heads into the batch dimension: (b, n, h*d) -> (b*h, n, d).
        q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        del q, k  # release the projections early; sim is (b*h, n, m)

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            # Finite most-negative value rather than -inf keeps fully-masked rows NaN-free after softmax.
            sim.masked_fill_(~mask, max_neg_value(sim))

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class MemoryEfficientCrossAttention(nn.Module):
    """Attention with the same constructor arguments and parameter names as ``CrossAttention``.

    The attention itself is computed by ``xformers.ops.memory_efficient_attention``.
    Attention masks are not supported.
    """

    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{heads} heads."
        )
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # Passed as ``op=`` to xformers; None presumably lets xformers pick the kernel.
        self.attention_op: Optional[Any] = None

    def _split_heads(self, t, b):
        """Reshape ``(b, n, heads * dim_head)`` to a contiguous ``(b * heads, n, dim_head)``."""
        n = t.shape[1]
        return (
            t.reshape(b, n, self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, n, self.dim_head)
            .contiguous()
        )

    def _merge_heads(self, out, b):
        """Reshape ``(b * heads, n, dim_head)`` back to ``(b, n, heads * dim_head)``."""
        n = out.shape[1]
        return (
            out.reshape(b, self.heads, n, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, n, self.heads * self.dim_head)
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (defaults to ``x``).

        Raises:
            NotImplementedError: if ``mask`` is given; checked before any attention work is done.
        """
        if exists(mask):
            raise NotImplementedError

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = (self._split_heads(t, b) for t in (q, k, v))

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        return self.to_out(self._merge_heads(out, b))


class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention, cross-attention, then feed-forward, each with a residual.

    Args:
        dim: token feature size.
        n_heads, d_head: number of attention heads and per-head size.
        dropout: dropout probability for the attention and feed-forward layers.
        context_dim: feature size of the cross-attention context (None: attn2 is self-attention).
        gated_ff: use a GEGLU feed-forward.
        checkpoint: flag forwarded to the ``checkpoint`` helper in ``forward``.
        disable_self_attn: make ``attn1`` attend to ``context`` instead of to ``x``.
    """

    # Attention implementation per mode; the xformers variant is used when it imported successfully.
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
        # Self-attention unless disable_self_attn, in which case attn1 also attends to the context.
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # Cross-attention; behaves as self-attention when called with context=None.
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # ``checkpoint`` here is the module-level helper imported from ldm (the constructor argument
        # of the same name only shadows it inside __init__); presumably it recomputes activations
        # during backward when self.checkpoint is True.
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """Transformer over the spatial positions of an image-like ``(b, c, h, w)`` tensor.

    The input is group-normalised and projected from ``in_channels`` to
    ``inner_dim = n_heads * d_head``: by a 1x1 conv, or, when ``use_linear``, by a Linear on the
    flattened tokens (cheaper than the 1x1 convs). It is then run as ``(b, h*w, inner_dim)``
    tokens through ``depth`` :class:`BasicTransformerBlock` s, projected back to ``in_channels``,
    reshaped to an image and added to the input.

    Args:
        in_channels: channel count ``c`` of the input and output.
        n_heads, d_head: attention heads and per-head size.
        depth: number of transformer blocks.
        dropout: dropout probability inside each block.
        context_dim: cross-attention context size. Either one value (or None) shared by all
            blocks, or a list with one entry per block.
        disable_self_attn: forwarded to every block.
        use_linear: use ``nn.Linear`` projections on tokens instead of 1x1 convs.
        use_checkpoint: forwarded to each block's ``checkpoint`` flag.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        # One context_dim per block. A single value, including None (which previously crashed on
        # ``None[d]``), is shared by every block instead of raising IndexError when depth > 1.
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if len(context_dim) == 1:
            context_dim = context_dim * depth

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        # proj_out maps inner_dim back to in_channels and is zero-initialised, so the residual
        # branch contributes nothing at initialisation.
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            # Linear(inner_dim, in_channels): the earlier Linear(in_channels, inner_dim) had its
            # sizes swapped and only worked when in_channels == inner_dim.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply the transformer to ``x`` of shape ``(b, c, h, w)``; returns the same shape.

        ``context`` is either one tensor (or None) shared by all blocks, or a list with one entry
        per block. If no context is given, cross-attention defaults to self-attention.
        """
        if not isinstance(context, list):
            context = [context]
        _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # A single context is reused by every block instead of raising IndexError for i > 0.
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in