# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    # IMPORTANT: TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore.

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        scale = 1 / math.sqrt(self.channels / self.num_heads)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states
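
# Usage sketch (illustrative, not part of the original file; shapes and sizes below are
# assumptions): AttentionBlock performs spatial self-attention over a 4-D feature map and
# returns a tensor of the same shape. Note the TODO above: the class is slated for deprecation.
#
#     attn = AttentionBlock(channels=64, num_head_channels=32)
#     feats = torch.randn(2, 64, 16, 16)    # (batch, channels, height, width)
#     out = attn(feats)                      # -> torch.Size([2, 64, 16, 16])
#
#     # Optional, requires CUDA and the xformers package:
#     # attn.set_use_memory_efficient_attention_xformers(True)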


class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attention layers should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether to use only cross-attention layers. If `True`, `attn1` attends to `encoder_hidden_states`
            instead of performing self-attention.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Whether to upcast the attention computation to float32. Useful for mixed-precision training.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None

        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.attn2 = None

        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
    ):
        # 1. Self-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

        if self.attn2 is not None:
            # 2. Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states
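
# Usage sketch (illustrative dimensions, not from the original file): BasicTransformerBlock acts
# on sequence-shaped hidden states of shape (batch, seq_len, dim) and can optionally cross-attend
# to encoder_hidden_states when cross_attention_dim is set.
#
#     block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40,
#                                   cross_attention_dim=768)
#     hidden = torch.randn(2, 64, 320)       # e.g. a flattened 8x8 feature map
#     context = torch.randn(2, 77, 768)      # e.g. text-encoder hidden states
#     out = block(hidden, encoder_hidden_states=context)   # -> torch.Size([2, 64, 320])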


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
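
# Usage sketch (assumed sizes): FeedForward expands the last dimension to mult * dim through the
# chosen activation, applies dropout, and projects back to dim_out (dim by default).
#
#     ff = FeedForward(dim=320, mult=4, activation_fn="geglu")
#     y = ff(torch.randn(2, 64, 320))        # -> torch.Size([2, 64, 320])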


class GELU(nn.Module):
    r"""
    GELU activation function
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
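
# Minimal sketch (shapes are assumptions): GEGLU projects to twice the output width, splits the
# result into a value and a gate along the last axis, and returns value * GELU(gate).
#
#     geglu = GEGLU(dim_in=320, dim_out=1280)
#     y = geglu(torch.randn(2, 64, 320))     # -> torch.Size([2, 64, 1280])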


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)
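
# Minimal sketch (illustrative sizes): ApproximateGELU combines a linear projection with the
# sigmoid-based approximation x * sigmoid(1.702 * x) of GELU (see section 2 of the linked paper).
#
#     approx = ApproximateGELU(dim_in=320, dim_out=1280)
#     y = approx(torch.randn(2, 64, 320))    # -> torch.Size([2, 64, 1280])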


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x
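
# Usage sketch (assumed sizes): AdaLayerNorm looks up a per-timestep embedding, maps it to a scale
# and a shift, and modulates the LayerNorm output with them. Because torch.chunk here splits along
# dim 0, the sketch passes a single scalar timestep index.
#
#     ada_norm = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)
#     x = torch.randn(64, 320)
#     t = torch.tensor(10)                   # diffusion timestep index
#     y = ada_norm(x, t)                     # -> torch.Size([64, 320])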