# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    and adapted to the N-d case.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    # IMPORTANT: TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore.

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor
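
    # Shape walkthrough for the two reshape helpers above (illustrative; `dim` must be divisible by num_heads):
    #   reshape_heads_to_batch_dim: (batch, seq_len, dim) -> (batch * num_heads, seq_len, dim // num_heads)
    #   reshape_batch_dim_to_heads: (batch * num_heads, seq_len, head_dim) -> (batch, seq_len, head_dim * num_heads)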

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers",
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

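        # 1 / sqrt(head_dim): the standard attention scaling factor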
        scale = 1 / math.sqrt(self.channels / self.num_heads)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
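            # attention_scores = scale * (query_proj @ key_proj^T); with beta=0 the uninitialized
            # `torch.empty` tensor is ignored and only supplies the output shape, dtype and device.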
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states
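

# A minimal usage sketch of AttentionBlock (the channel counts and spatial size below are
# illustrative assumptions, not values taken from this file):
#
#     attn = AttentionBlock(channels=64, num_head_channels=32)
#     sample = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
#     out = attn(sample)                    # same shape as the input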


class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether the first attention layer (`attn1`) should attend to `encoder_hidden_states` instead of
            performing self-attention.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Whether to upcast the attention computation to `float32`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None

        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is None
        else:
            self.attn2 = None

        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
    ):
        # 1. Self-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

        if self.attn2 is not None:
            # 2. Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states
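

# A minimal usage sketch of BasicTransformerBlock (the dimensions and the encoder sequence length
# below are illustrative assumptions, not values taken from this file):
#
#     block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768)
#     hidden_states = torch.randn(2, 64, 320)          # (batch, seq_len, dim)
#     encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder features
#     out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)  # (2, 64, 320)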


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
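

# A minimal usage sketch of FeedForward (the dimensions below are illustrative assumptions):
#
#     ff = FeedForward(dim=320)           # inner_dim = 4 * 320 with the default `mult`
#     out = ff(torch.randn(2, 64, 320))   # (2, 64, 320), same trailing dim as the input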


class GELU(nn.Module):
    r"""
    GELU activation function
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
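

# GEGLU projects to 2 * dim_out and splits the result into a value half and a gate half
# (illustrative dimensions, not values taken from this file):
#
#     geglu = GEGLU(dim_in=320, dim_out=1280)
#     out = geglu(torch.randn(2, 64, 320))   # (2, 64, 1280) = value * gelu(gate)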


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x
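

# A minimal usage sketch of AdaLayerNorm (the embedding size and number of diffusion steps are
# illustrative assumptions, not values taken from this file):
#
#     ada_norm = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)
#     x = torch.randn(2, 64, 320)
#     timestep = torch.tensor(10)   # a single (scalar) timestep index
#     out = ada_norm(x, timestep)   # (2, 64, 320)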