renzhc / diffusers_dcu · Commit 888468dd (unverified)

Authored Oct 31, 2022 by Patrick von Platen, committed by GitHub on Oct 31, 2022

Remove nn sequential (#1086)

* Remove nn sequential
* up

Parent: 17c2c060
Showing 1 changed file with 18 additions and 5 deletions

src/diffusers/models/attention.py (+18 / -5)
@@ -244,7 +244,9 @@ class CrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
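Note on the to_out change above: nn.ModuleList registers its children under the same numeric indices that nn.Sequential used, so the parameter names in the state dict (to_out.0.weight, to_out.0.bias) stay the same and existing checkpoints keep loading. A minimal sketch, with made-up dimensions, illustrating that the key layout is unchanged:

    # illustrative only, not part of the commit; 320 is an arbitrary example size
    import torch.nn as nn

    old_to_out = nn.Sequential(nn.Linear(320, 320), nn.Dropout(0.0))

    new_to_out = nn.ModuleList([])
    new_to_out.append(nn.Linear(320, 320))
    new_to_out.append(nn.Dropout(0.0))

    # both expose the Linear under index 0, so the state-dict keys match
    assert list(old_to_out.state_dict().keys()) == list(new_to_out.state_dict().keys())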
@@ -283,7 +285,11 @@ class CrossAttention(nn.Module):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states

     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
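For the forward pass above, applying the two to_out submodules one after the other gives the same result the old nn.Sequential call produced. A small sketch under made-up shapes:

    # illustrative only; shapes and sizes are arbitrary
    import torch
    import torch.nn as nn

    to_out = nn.ModuleList([nn.Linear(8, 8), nn.Dropout(0.0)])
    hidden_states = torch.randn(2, 4, 8)

    # new style: explicit per-module calls (linear proj, then dropout)
    out_explicit = to_out[1](to_out[0](hidden_states))
    # old style: one nn.Sequential call over the same modules
    out_sequential = nn.Sequential(*to_out)(hidden_states)

    assert torch.allclose(out_explicit, out_sequential)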
@@ -354,12 +360,19 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)

-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))

     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states

     # feedforward
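The FeedForward change follows the same pattern: the modules are collected in an nn.ModuleList and forward threads hidden_states through them explicitly, reproducing the data flow nn.Sequential provided while keeping each stage addressable on its own. A self-contained sketch with made-up sizes (the GEGLU below is a stand-in for the gated-GELU projection already defined in attention.py):

    # illustrative only; dims 16/64 are arbitrary
    import torch
    import torch.nn as nn

    class GEGLU(nn.Module):  # stand-in for diffusers' GEGLU
        def __init__(self, dim_in, dim_out):
            super().__init__()
            self.proj = nn.Linear(dim_in, dim_out * 2)

        def forward(self, x):
            x, gate = self.proj(x).chunk(2, dim=-1)
            return x * torch.nn.functional.gelu(gate)

    net = nn.ModuleList([])
    net.append(GEGLU(16, 64))      # project in
    net.append(nn.Dropout(0.0))    # project dropout
    net.append(nn.Linear(64, 16))  # project out

    hidden_states = torch.randn(2, 3, 16)
    for module in net:             # same data flow nn.Sequential gave
        hidden_states = module(hidden_states)
    print(hidden_states.shape)     # torch.Size([2, 3, 16])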