Commit 50e66f2f (unverified), renzhc/diffusers_dcu
Authored Jul 26, 2024 by Sayak Paul; committed by GitHub on Jul 26, 2024.
Parent: 9b8c8605

[Chore] remove all is from auraflow. (#8980)

The change removes the leftover debugging parameter `i` from the `forward` methods of the AuraFlow transformer blocks and from the call sites that threaded it through to attention.
Showing 1 changed file with 5 additions and 5 deletions:

src/diffusers/models/transformers/auraflow_transformer_2d.py (+5, -5)
```diff
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
+    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
         residual = hidden_states
 
         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
+        attn_output = self.attn(hidden_states=norm_hidden_states)
 
         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
```
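Before this change, `AuraFlowSingleTransformerBlock.forward` accepted an unused debugging index (`i=9999`) and threaded it into the attention call; afterwards the block takes only the hidden states and the timestep embedding. A minimal sketch of the updated call, assuming toy dimensions rather than the real AuraFlow configuration:

```python
import torch

from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowSingleTransformerBlock

# Illustrative sizes only; dim is chosen as num_attention_heads * attention_head_dim.
block = AuraFlowSingleTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)

hidden_states = torch.randn(1, 16, 256)  # (batch, tokens, dim)
temb = torch.randn(1, 256)               # timestep embedding, (batch, dim)

# Post-commit signature: no `i` argument.
out = block(hidden_states=hidden_states, temb=temb)
print(out.shape)  # torch.Size([1, 16, 256])

# Passing the removed keyword now raises a TypeError, which is the point of the cleanup:
# block(hidden_states=hidden_states, temb=temb, i=9999)
```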
```diff
@@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
```
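The joint (MMDiT-style) block receives the same cleanup: its `forward` loses the trailing `i=0`. A hedged sketch of the updated call, again with assumed toy dimensions; note that the block returns the text stream first, which matches the unpacking in the model-level loop shown in the last hunk below:

```python
import torch

from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowJointTransformerBlock

block = AuraFlowJointTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)

hidden_states = torch.randn(1, 16, 256)         # image tokens
encoder_hidden_states = torch.randn(1, 8, 256)  # text tokens
temb = torch.randn(1, 256)                      # timestep embedding

# Post-commit signature: hidden_states, encoder_hidden_states, and temb only.
encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
)
```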
```diff
@@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
 
         # Process attention outputs for the `hidden_states`.
```
```diff
@@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
                 )
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
+                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                 )
 
         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
```
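For callers of the full model nothing changes: `AuraFlowTransformer2DModel.forward` never exposed `i`; the index was only threaded internally into the joint blocks. A minimal end-to-end sketch under an assumed tiny config (parameter names follow the constructor in this file; shipped AuraFlow checkpoints use far larger values):

```python
import torch

from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel

# Tiny illustrative config; e.g. the real model uses attention_head_dim=256 and many more layers.
model = AuraFlowTransformer2DModel(
    sample_size=8,
    patch_size=2,
    in_channels=4,
    num_mmdit_layers=1,
    num_single_dit_layers=1,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=128,
    caption_projection_dim=128,  # matches num_attention_heads * attention_head_dim
    out_channels=4,
    pos_embed_max_size=64,
)

latents = torch.randn(1, 4, 8, 8)        # (batch, in_channels, height, width)
prompt_embeds = torch.randn(1, 16, 128)  # (batch, seq_len, joint_attention_dim)
timestep = torch.tensor([1.0])

sample = model(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
).sample
print(sample.shape)  # torch.Size([1, 4, 8, 8])
```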