Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8a4c3e50
Unverified
Commit
8a4c3e50
authored
Dec 27, 2022
by
William Held
Committed by
GitHub
Dec 27, 2022
Browse files
Width was typod as weight (#1800)
* Width was typod as weight * Run Black
parent
68e24259
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
9 deletions
+5
-9
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+5
-9
No files found.
src/diffusers/models/attention.py
View file @
8a4c3e50
...
@@ -204,17 +204,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
...
@@ -204,17 +204,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
"""
"""
# 1. Input
# 1. Input
if
self
.
is_input_continuous
:
if
self
.
is_input_continuous
:
batch
,
channel
,
height
,
w
eight
=
hidden_states
.
shape
batch
,
channel
,
height
,
w
idth
=
hidden_states
.
shape
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
if
not
self
.
use_linear_projection
:
if
not
self
.
use_linear_projection
:
hidden_states
=
self
.
proj_in
(
hidden_states
)
hidden_states
=
self
.
proj_in
(
hidden_states
)
inner_dim
=
hidden_states
.
shape
[
1
]
inner_dim
=
hidden_states
.
shape
[
1
]
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
batch
,
height
*
w
eight
,
inner_dim
)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
batch
,
height
*
w
idth
,
inner_dim
)
else
:
else
:
inner_dim
=
hidden_states
.
shape
[
1
]
inner_dim
=
hidden_states
.
shape
[
1
]
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
batch
,
height
*
w
eight
,
inner_dim
)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
3
,
1
).
reshape
(
batch
,
height
*
w
idth
,
inner_dim
)
hidden_states
=
self
.
proj_in
(
hidden_states
)
hidden_states
=
self
.
proj_in
(
hidden_states
)
elif
self
.
is_input_vectorized
:
elif
self
.
is_input_vectorized
:
hidden_states
=
self
.
latent_image_embedding
(
hidden_states
)
hidden_states
=
self
.
latent_image_embedding
(
hidden_states
)
...
@@ -231,15 +231,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
...
@@ -231,15 +231,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
# 3. Output
if
self
.
is_input_continuous
:
if
self
.
is_input_continuous
:
if
not
self
.
use_linear_projection
:
if
not
self
.
use_linear_projection
:
hidden_states
=
(
hidden_states
=
hidden_states
.
reshape
(
batch
,
height
,
width
,
inner_dim
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
hidden_states
.
reshape
(
batch
,
height
,
weight
,
inner_dim
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
)
hidden_states
=
self
.
proj_out
(
hidden_states
)
hidden_states
=
self
.
proj_out
(
hidden_states
)
else
:
else
:
hidden_states
=
self
.
proj_out
(
hidden_states
)
hidden_states
=
self
.
proj_out
(
hidden_states
)
hidden_states
=
(
hidden_states
=
hidden_states
.
reshape
(
batch
,
height
,
width
,
inner_dim
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
hidden_states
.
reshape
(
batch
,
height
,
weight
,
inner_dim
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
)
output
=
hidden_states
+
residual
output
=
hidden_states
+
residual
elif
self
.
is_input_vectorized
:
elif
self
.
is_input_vectorized
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment