renzhc / diffusers_dcu · Commits

Commit e4a9fb3b (unverified)
Authored Mar 01, 2023 by Pedro Cuenca, committed by GitHub on Mar 01, 2023
Parent: eadf0e25

Bring Flax attention naming in sync with PyTorch (#2511)
Changes: 1 changed file with 10 additions and 6 deletions

src/diffusers/models/attention_flax.py  (+10, -6)
@@ -16,7 +16,7 @@ import flax.linen as nn
 import jax.numpy as jnp


-class FlaxAttentionBlock(nn.Module):
+class FlaxCrossAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -118,10 +118,10 @@ class FlaxBasicTransformerBlock(nn.Module):

     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
-        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -242,10 +242,14 @@ class FlaxTransformer2DModel(nn.Module):
         return hidden_states


-class FlaxGluFeedForward(nn.Module):
+class FlaxFeedForward(nn.Module):
     r"""
-    Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
+    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+    [`FeedForward`] class, with the following simplifications:
+    - The activation function is currently hardcoded to a gated linear unit from:
     https://arxiv.org/abs/2002.05202
+    - `dim_out` is equal to `dim`.
+    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].

     Parameters:
         dim (:obj:`int`):
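The diff is a pure rename: FlaxAttentionBlock becomes FlaxCrossAttention and FlaxGluFeedForward becomes FlaxFeedForward, matching the names of the corresponding PyTorch modules. The sketch below shows how downstream code that imported the old names directly would be updated; it assumes a diffusers build that includes this commit, the constructor arguments are copied from the calls visible in the diff, and the concrete sizes (320, 8, 40) are illustrative only.

# Sketch only: updating imports after the rename in this commit.
# Constructor arguments mirror the calls visible in the diff above;
# the sizes 320 / 8 / 40 are illustrative, not taken from this commit.
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxCrossAttention, FlaxFeedForward

# previously: FlaxAttentionBlock(320, 8, 40, 0.0, dtype=jnp.float32)
attn = FlaxCrossAttention(320, 8, 40, 0.0, dtype=jnp.float32)

# previously: FlaxGluFeedForward(dim=320, dropout=0.0, dtype=jnp.float32)
ff = FlaxFeedForward(dim=320, dropout=0.0, dtype=jnp.float32)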
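The new docstring describes what FlaxFeedForward computes: two Linear layers separated by a gated linear unit (https://arxiv.org/abs/2002.05202), with the hidden width fixed at dim * 4 and dim_out equal to dim. The standalone sketch below illustrates that structure using only public Flax APIs; the class name GegluFeedForwardSketch is invented for illustration and is not the library's implementation (dropout is omitted for brevity).

# Minimal sketch of the structure described by the docstring above:
# Dense(dim -> 2 * 4*dim), split into a value half and a gate half,
# apply GELU to the gate, then Dense(4*dim -> dim).
# Not the diffusers source; names are made up.
import jax
import jax.numpy as jnp
import flax.linen as nn

class GegluFeedForwardSketch(nn.Module):  # hypothetical name, for illustration
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        # a single projection produces both the value half and the gate half
        self.proj_in = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.proj_out = nn.Dense(self.dim, dtype=self.dtype)  # dim_out == dim

    def __call__(self, hidden_states):
        hidden_linear, hidden_gate = jnp.split(self.proj_in(hidden_states), 2, axis=-1)
        hidden_states = hidden_linear * nn.gelu(hidden_gate)  # gated linear unit
        return self.proj_out(hidden_states)

# usage: output shape equals input shape because dim_out == dim
ff = GegluFeedForwardSketch(dim=320)
x = jnp.ones((1, 64, 320))
params = ff.init(jax.random.PRNGKey(0), x)
y = ff.apply(params, x)  # shape (1, 64, 320)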