# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
from torch import nn

from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero


@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that combines visual features and object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # we need a linear projection since we concatenate the visual features and the object features
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

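        # Gated residual injection (as used for GLIGEN): both gates are initialized to 0, so
        # tanh(alpha) = 0 and the block starts out as an identity mapping; the attention over
        # the concatenated [visual, object] tokens is blended in as the alphas are learned.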
        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x
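
# A minimal usage sketch for `GatedSelfAttentionDense` (the shapes below are illustrative
# assumptions, not values taken from a specific pipeline):
#
#   fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
#   x = torch.randn(2, 64, 320)      # (batch, num_visual_tokens, query_dim)
#   objs = torch.randn(2, 30, 768)   # (batch, num_grounding_tokens, context_dim)
#   out = fuser(x, objs)             # (2, 64, 320); only the visual tokens are returned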


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"` or `"ada_norm_single"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is None
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
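        # For `ada_norm_single`, the six modulation parameters (shift/scale/gate for the
        # attention and the feed-forward sub-blocks) come from this learned table plus the
        # conditioning projected from the timestep (see `forward`).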
        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim
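
    # A minimal sketch of how chunking is typically enabled (values are illustrative):
    #
    #   block.set_chunk_feed_forward(chunk_size=256, dim=1)
    #
    # The feed-forward in `forward` is then applied to 256-token slices along dimension 1
    # (the sequence dimension for `(batch, seq, dim)` inputs), trading some speed for lower
    # peak activation memory.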

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and not self.use_ada_layer_norm_single:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
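
# A minimal usage sketch for `BasicTransformerBlock` (all sizes are illustrative assumptions;
# with the default `norm_type="layer_norm"` no `timestep` is required):
#
#   block = BasicTransformerBlock(
#       dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768
#   )
#   hidden_states = torch.randn(2, 4096, 320)        # (batch, image_tokens, dim)
#   encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text_tokens, cross_attention_dim)
#   out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)  # (2, 4096, 320)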


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
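        # `scale` is the LoRA scale: GEGLU (and, without the PEFT backend, LoRACompatibleLinear)
        # accepts it, while plain modules such as nn.Dropout and nn.Linear are called without it.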
        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
        for module in self.net:
            if isinstance(module, compatible_cls):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states
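
# A minimal usage sketch for `FeedForward` (sizes are illustrative assumptions):
#
#   ff = FeedForward(dim=320, mult=4, activation_fn="geglu")
#   hidden_states = torch.randn(2, 64, 320)
#   out = ff(hidden_states)  # (2, 64, 320); `scale` is only consumed by the LoRA-aware layers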