# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    and adapted to the N-d case.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    # IMPORTANT: TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore.

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False
        self._use_2_0_attn = True
        self._attention_op = None

    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)
        if merge_head_and_batch:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
        head_size = self.num_heads

        if unmerge_head_and_batch:
            batch_head_size, seq_len, dim = tensor.shape
            batch_size = batch_head_size // head_size

            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
        else:
            batch_size, _, seq_len, dim = tensor.shape

        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
        return tensor

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self._attention_op = attention_op

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        scale = 1 / math.sqrt(self.channels / self.num_heads)

        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(
                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        elif use_torch_2_0_attn:
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states
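

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# It shows the shape contract of the deprecated `AttentionBlock`: it consumes a 4-D feature map
# (batch, channels, height, width) and returns a tensor of the same shape. The helper name and
# all concrete sizes below are assumptions chosen only for the example.
def _example_attention_block_usage():
    attn_block = AttentionBlock(channels=64, num_head_channels=32)  # 64 / 32 = 2 attention heads
    feature_map = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    out = attn_block(feature_map)  # spatial positions attend to each other; shape is preserved
    assert out.shape == feature_map.shape
    return out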


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
            # prepare attention mask here

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
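

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# It shows the block's layout: self-attention, optional cross-attention over `encoder_hidden_states`,
# and a feed-forward layer, each preceded by its own normalization. The helper name and the
# batch/sequence/feature sizes are assumptions chosen only for the example.
def _example_basic_transformer_block_usage():
    block = BasicTransformerBlock(
        dim=320,
        num_attention_heads=8,
        attention_head_dim=40,
        cross_attention_dim=768,
    )
    hidden_states = torch.randn(2, 64, 320)  # (batch, sequence_length, dim)
    encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder features for cross-attention
    out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)  # (2, 64, 320)
    return out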


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
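

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# With the defaults (`mult=4`, `activation_fn="geglu"`), the layer expands `dim` to `4 * dim`
# through a gated projection, applies dropout, and projects back to `dim_out` (here `dim`).
# The helper name and the tensor sizes are assumptions chosen only for the example.
def _example_feed_forward_usage():
    ff = FeedForward(dim=64)  # GEGLU(64 -> 256), Dropout, Linear(256 -> 64)
    hidden_states = torch.randn(2, 16, 64)  # (batch, sequence_length, dim)
    return ff(hidden_states)  # shape preserved: (2, 16, 64)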


class GELU(nn.Module):
    r"""
    GELU activation function, with optional tanh approximation enabled via `approximate="tanh"`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states
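

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# `GELU` here is a linear projection followed by the GELU nonlinearity; passing `approximate="tanh"`
# selects the tanh approximation. The helper name and sizes are assumptions chosen only for the example.
def _example_gelu_usage():
    gelu_proj = GELU(dim_in=64, dim_out=256, approximate="tanh")
    return gelu_proj(torch.randn(2, 16, 64))  # (2, 16, 256)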


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
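

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# The projection produces `2 * dim_out` features; one half is the value and the other half,
# passed through GELU, gates it elementwise. The helper name and sizes are assumptions
# chosen only for the example.
def _example_geglu_usage():
    geglu = GEGLU(dim_in=64, dim_out=256)  # internally projects 64 -> 512, then splits in two
    return geglu(torch.randn(2, 16, 64))  # (2, 16, 256)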


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
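

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# adaLN-Zero produces a modulated input for self-attention plus the gate/shift/scale chunks that
# `BasicTransformerBlock` applies later. It assumes `CombinedTimestepLabelEmbeddings` returns one
# conditioning vector per batch element; the helper name and all sizes below are assumptions
# chosen only for the example.
def _example_ada_layer_norm_zero_usage():
    norm = AdaLayerNormZero(embedding_dim=128, num_embeddings=1000)  # 1000 class labels
    x = torch.randn(2, 16, 128)  # (batch, sequence_length, embedding_dim)
    timesteps = torch.randint(0, 1000, (2,))
    class_labels = torch.randint(0, 1000, (2,))
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, timesteps, class_labels)
    return x_mod  # same shape as `x`; the remaining tensors are (batch, embedding_dim)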


class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.act = None
        if act_fn == "swish":
            self.act = lambda x: F.silu(x)
        elif act_fn == "mish":
            self.act = nn.Mish()
        elif act_fn == "silu":
            self.act = nn.SiLU()
        elif act_fn == "gelu":
            self.act = nn.GELU()

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
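

# Illustrative usage sketch added for documentation purposes (not part of the diffusers API).
# The conditioning embedding is projected to a per-channel scale and shift that modulate the
# group-normalized feature map. The helper name and sizes are assumptions chosen only for the example.
def _example_ada_group_norm_usage():
    norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=8, act_fn="silu")
    x = torch.randn(2, 64, 16, 16)  # (batch, out_dim, height, width)
    emb = torch.randn(2, 512)  # conditioning embedding, e.g. a timestep embedding
    return norm(x, emb)  # (2, 64, 16, 16)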