attention.py 14.8 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from typing import Any, Dict, Optional
15
16

import torch
Patrick von Platen's avatar
Patrick von Platen committed
17
import torch.nn.functional as F
18
19
from torch import nn

20
from ..utils import maybe_allow_in_graph
21
from .activations import get_activation
Patrick von Platen's avatar
Patrick von Platen committed
22
from .attention_processor import Attention
Kashif Rasul's avatar
Kashif Rasul committed
23
from .embeddings import CombinedTimestepLabelEmbeddings
24
25


26
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: attention, optional second attention, and a feed-forward layer. Each sub-block is
    applied to a normalized copy of the hidden states and its output is added back as a residual.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`. Required when `norm_type`
            is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*, defaults to `False`):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Forwarded as `upcast_attention` to both `Attention` layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether the plain `nn.LayerNorm` layers have learnable affine parameters.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"ada_norm"` or `"ada_norm_zero"` selects a timestep-conditioned first norm (and requires
            `num_embeds_ada_norm`); any other value uses plain `nn.LayerNorm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout in the feed-forward layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Adaptive norms are only enabled when an embedding table size is available.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn (becomes cross-attention when `only_cross_attention` is set)
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """
        Enable or disable chunked feed-forward computation.

        Args:
            chunk_size: Size of each chunk along `dim`, or `None` to run the feed-forward in a single pass.
            dim: Axis of the normalized hidden states to split into chunks. Defaults to 0.

        Raises:
            ValueError: If `chunk_size` is not `None` and not strictly positive. Such a value would otherwise only
                fail later inside `forward`, for example as a division by zero.
        """
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError(f"`chunk_size` has to be a positive integer or None, but is {chunk_size}.")
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Apply attention, optional second attention and feed-forward, each with a residual connection.

        Args:
            hidden_states: Input of the block. NOTE(review): the `ada_norm_zero` gates are applied with
                `unsqueeze(1)` / `[:, None]`, which suggests a (batch, seq_len, dim) layout. Confirm with callers.
            attention_mask: Passed as `attention_mask` to the first attention layer.
            encoder_hidden_states: Context for the second attention layer. Also used by the first one when
                `only_cross_attention` is set.
            encoder_attention_mask: Passed as `attention_mask` to the second attention layer.
            timestep: Conditioning for the adaptive norms; ignored by plain `nn.LayerNorm`.
            cross_attention_kwargs: Extra keyword arguments forwarded to both attention layers.
            class_labels: Class conditioning, only consumed by the `ada_norm_zero` first norm.

        Returns:
            The hidden states after the three residual sub-blocks.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
Patrick von Platen's avatar
Patrick von Platen committed
205
206
207


class FeedForward(nn.Module):
    r"""
    A feed-forward layer: activation (with its own input projection), dropout, and an output projection.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or `"geglu-approximate"`.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Each activation module also performs the input projection dim -> inner_dim.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            # NOTE(review): despite the name, this maps to the non-gated ApproximateGELU (its projection is
            # dim -> inner_dim, whereas GEGLU's is dim -> 2 * inner_dim). Kept as-is: changing it would change
            # parameter shapes.
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(
                f"Unsupported `activation_fn`: {activation_fn!r}. Expected one of 'gelu', 'gelu-approximate',"
                " 'geglu' or 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        # Apply the layers in registration order.
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
Patrick von Platen's avatar
Patrick von Platen committed
257

Patrick von Platen's avatar
Patrick von Platen committed
258

259
260
class GELU(nn.Module):
    r"""
    Linear projection followed by a GELU activation.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): Forwarded to `torch.nn.functional.gelu`; pass
            `"tanh"` to use the tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16, so compute in float32 and cast back.
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast, approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        return self.gelu(projected)


Patrick von Platen's avatar
Patrick von Platen committed
281
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    One linear layer produces both a value half and a gate half; the output is `value * gelu(gate)`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # Twice the output width: first half is the value, second half the gate.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
Will Berman's avatar
Will Berman committed
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337


class ApproximateGELU(nn.Module):
    """
    Linear projection followed by the sigmoid approximation of GELU: `y * sigmoid(1.702 * y)`.

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.

    A non-affine LayerNorm whose output is scaled by `(1 + scale)` and shifted by `shift`. Both values are predicted
    from an embedding of `timestep`.

    Parameters:
        embedding_dim (`int`): Channel dimension of the normalized input and of the timestep embedding.
        num_embeddings (`int`): Number of entries in the timestep embedding table.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        """
        Args:
            x: Tensor whose last dimension is `embedding_dim`.
            timestep: Integer index tensor. A 0-d timestep produces one scale/shift shared by all of `x`. A 1-d
                timestep of length B produces per-sample scale/shift, assuming `x` has the batch axis first.

        Returns:
            The normalized and modulated `x`.
        """
        emb = self.linear(self.silu(self.emb(timestep)))
        if emb.ndim == 2:
            # Batched timesteps: emb is (B, 2 * embedding_dim). Split on the feature axis (splitting on dim 0 would
            # split the batch) and add singleton axes so each sample's values broadcast over the middle axes of x.
            scale, shift = emb.chunk(2, dim=-1)
            broadcast_shape = (emb.shape[0],) + (1,) * (x.ndim - 2) + (scale.shape[-1],)
            scale = scale.reshape(broadcast_shape)
            shift = shift.reshape(broadcast_shape)
        else:
            # Single (0-d) timestep: emb is 1-d, so this splits it into the scale and shift halves.
            scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x
Kashif Rasul's avatar
Kashif Rasul committed
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358


class AdaLayerNormZero(nn.Module):
    """
    Adaptive layer norm zero (adaLN-Zero).

    From a combined timestep/class-label embedding, predicts six modulation tensors. The shift/scale pair for the
    attention branch is applied to the normalized input here. The attention gate and the feed-forward
    shift/scale/gate are returned for the caller to apply.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        # Six modulation tensors, each of width embedding_dim.
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        conditioning = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(conditioning))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)
        # NOTE(review): `[:, None]` inserts axis 1, presumably so (batch, dim) broadcasts over a
        # (batch, seq_len, dim) input; confirm with callers.
        modulated = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
359
360
361
362
363
364
365
366
367
368
369
370
371


class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.

    The input is group-normalized without affine parameters. It is then scaled by `(1 + scale)` and shifted by
    `shift`; both are projected from `emb`, optionally after passing it through an activation.

    Parameters:
        embedding_dim (`int`): Width of the conditioning embedding.
        out_dim (`int`): Number of channels of the normalized input.
        num_groups (`int`): Number of groups for group normalization.
        act_fn (`str`, *optional*): Name passed to `get_activation` to build an activation applied to `emb`.
        eps (`float`, *optional*, defaults to 1e-5): Epsilon for group normalization.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        self.act = None if act_fn is None else get_activation(act_fn)

        # Scale and shift are produced together, then split along the channel axis.
        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        if self.act:
            emb = self.act(emb)
        # Two trailing singleton axes let (B, C) broadcast over the spatial axes of x; presumably x is
        # (B, out_dim, H, W) — confirm with callers.
        scale_shift = self.linear(emb)[:, :, None, None]
        scale, shift = scale_shift.chunk(2, dim=1)

        normalized = F.group_norm(x, self.num_groups, eps=self.eps)
        return normalized * (1 + scale) + shift