"vscode:/vscode.git/clone" did not exist on "45b4dcf0375ae01223660132869db58ca298eb09"
attention_flax.py 21 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp

from ..utils import logging


logger = logging.get_logger(__name__)


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights
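
# A note on the recombination above (a sketch of the math, not library API): each key
# chunk i contributes unnormalized values V_i = sum_k exp(s_k - m_i) * v_k and weights
# W_i = sum_k exp(s_k - m_i), where m_i is the per-chunk maximum score. With the global
# maximum m = max_i m_i, rescaling every chunk by exp(m_i - m) makes the partial sums
# comparable, so
#
#     softmax(s) @ v == (sum_i exp(m_i - m) * V_i) / (sum_i exp(m_i - m) * W_i)
#
# which is exactly `all_values / all_weights` computed above.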


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://huggingface.co/papers/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size used to split the query array; must divide query_length evenly without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size used to split the key and value arrays; must divide key_value_length evenly without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # carry value: index of the next query chunk (unused after the scan)
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner,
        init=0,
        xs=None,
        length=math.ceil(num_q / query_chunk_size),  # number of query chunks to process
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
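
# A minimal usage sketch for `jax_memory_efficient_attention` (hypothetical shapes,
# shown for illustration only; both chunk sizes must divide the corresponding
# sequence lengths without remainder):
#
#     q = jnp.ones((2, 1024, 8, 64))  # (batch, query_length, heads, dim_per_head)
#     k = jnp.ones((2, 4096, 8, 64))  # (batch, key_value_length, heads, dim_per_head)
#     v = jnp.ones((2, 4096, 8, 64))  # (batch, key_value_length, heads, dim_per_head)
#     out = jax_memory_efficient_attention(q, k, v, query_chunk_size=1024, key_chunk_size=4096)
#     assert out.shape == (2, 1024, 8, 64)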


class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://huggingface.co/papers/1706.03762

    Parameters:
        query_dim (:obj:`int`):
            Input hidden states dimension
        heads (:obj:`int`, *optional*, defaults to 8):
            Number of heads
        dim_head (:obj:`int`, *optional*, defaults to 64):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://huggingface.co/papers/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """

    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )

        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor
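
    # Shape note (for illustration only): `reshape_heads_to_batch_dim` maps
    # (batch, seq_len, heads * dim_head) -> (batch * heads, seq_len, dim_head),
    # and `reshape_batch_dim_to_heads` is its inverse.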

    def __call__(self, hidden_states, context=None, deterministic=True):
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        if self.split_head_dim:
            b = hidden_states.shape[0]
            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
        else:
            query_states = self.reshape_heads_to_batch_dim(query_proj)
            key_states = self.reshape_heads_to_batch_dim(key_proj)
            value_states = self.reshape_heads_to_batch_dim(value_proj)

        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if statement creates a chunk size for each layer of the unet
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet

            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)
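            # For example, a flattened 64x64 latent (flatten_latent_dim == 4096) yields query_chunk_size == 64.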

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )
            hidden_states = hidden_states.transpose(1, 0, 2)
            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        else:
            # compute attentions
            if self.split_head_dim:
                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
            else:
                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)

            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

            # attend to values
            if self.split_head_dim:
                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
                b = hidden_states.shape[0]
                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
            else:
                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)
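
# A minimal init/apply sketch for `FlaxAttention` (hypothetical sizes, shown for
# illustration only):
#
#     attn = FlaxAttention(query_dim=320, heads=8, dim_head=40)
#     hidden_states = jnp.ones((1, 64, 320))  # (batch, seq_len, query_dim)
#     variables = attn.init(jax.random.PRNGKey(0), hidden_states)
#     out = attn.apply(variables, hidden_states)  # same shape as the input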


class FlaxBasicTransformerBlock(nn.Module):
    r"""
    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
    https://huggingface.co/papers/1706.03762

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://huggingface.co/papers/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )

        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            self.dim,
            self.n_heads,
            self.d_head,
            self.dropout,
            self.use_memory_efficient_attention,
            self.split_head_dim,
            dtype=self.dtype,
        )
        # cross attention
        self.attn2 = FlaxAttention(
            self.dim,
            self.n_heads,
            self.d_head,
            self.dropout,
            self.use_memory_efficient_attention,
            self.split_head_dim,
            dtype=self.dtype,
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
        else:
            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        return self.dropout_layer(hidden_states, deterministic=deterministic)
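
# A minimal init/apply sketch for `FlaxBasicTransformerBlock` (hypothetical sizes,
# shown for illustration only):
#
#     block = FlaxBasicTransformerBlock(dim=320, n_heads=8, d_head=40)
#     hidden_states = jnp.ones((1, 1024, 320))  # (batch, tokens, dim)
#     context = jnp.ones((1, 77, 768))          # (batch, context_tokens, context_dim)
#     variables = block.init(jax.random.PRNGKey(0), hidden_states, context)
#     out = block.apply(variables, hidden_states, context)  # (1, 1024, 320)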


class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://huggingface.co/papers/1506.02025

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`):
            Whether to use a linear (`nn.Dense`) layer instead of a 1x1 convolution for the input and output
            projections.
        only_cross_attention (`bool`, defaults to `False`):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://huggingface.co/papers/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )

        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        if self.use_linear_projection:
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_in = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
            )
            for _ in range(self.depth)
        ]

        if self.use_linear_projection:
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_out = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height * width, channels)
            hidden_states = self.proj_in(hidden_states)
        else:
            hidden_states = self.proj_in(hidden_states)
            hidden_states = hidden_states.reshape(batch, height * width, channels)

        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channels)
        else:
            hidden_states = hidden_states.reshape(batch, height, width, channels)
            hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states + residual
        return self.dropout_layer(hidden_states, deterministic=deterministic)
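
# A minimal init/apply sketch for `FlaxTransformer2DModel` (hypothetical sizes, shown
# for illustration only; inputs are channels-last, and in practice `n_heads * d_head`
# matches `in_channels` so the reshape and residual addition line up):
#
#     model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1)
#     hidden_states = jnp.ones((1, 32, 32, 320))  # (batch, height, width, in_channels)
#     context = jnp.ones((1, 77, 768))            # (batch, context_tokens, context_dim)
#     variables = model.init(jax.random.PRNGKey(0), hidden_states, context)
#     out = model.apply(variables, hidden_states, context)  # (1, 32, 32, 320)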


class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
    https://huggingface.co/papers/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )

        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://huggingface.co/papers/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )

        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
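
# A note on the gating above (a sketch, not library API): `self.proj` produces
# 2 * inner_dim features that are split into a "linear" half and a "gate" half,
# so for an input x the layer computes
#
#     GEGLU(x) = (x @ W_linear + b_linear) * gelu(x @ W_gate + b_gate)
#
# i.e. the GELU variant of a gated linear unit from https://huggingface.co/papers/2002.05202.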