renzhc / diffusers_dcu / Commits / 4ae54b37

Unverified commit 4ae54b37, authored Mar 13, 2023 by Patrick von Platen, committed by GitHub on Mar 13, 2023
Parent: fa7a5761

[attention] Fix attention (#2656)

* [attention] Fix attention
* fix
* correct
Changes: 2 changed files with 6 additions and 3 deletions (+6 -3)

src/diffusers/models/attention.py                             +5 -2
tests/pipelines/stable_diffusion/test_stable_diffusion.py     +1 -1
src/diffusers/models/attention.py  (view file @ 4ae54b37)

...
@@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
+        encoder_attention_mask=None,
         timestep=None,
-        attention_mask=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
...
@@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+            # prepare attention mask here

             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states
...
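The change above gives BasicTransformerBlock a dedicated encoder_attention_mask argument and passes it, rather than the self-attention attention_mask, into attn2 (the cross-attention layer). The two masks cover different sequences: the self-attention mask spans the latent/image tokens, while the cross-attention mask must span the text-encoder tokens, so reusing one for the other is a shape and semantics mismatch. Below is a minimal, self-contained sketch of that distinction using torch.nn.MultiheadAttention instead of the diffusers attention classes; the variable names (batch, latent_tokens, text_tokens) are illustrative only and none of this is diffusers code.

import torch
import torch.nn as nn

batch, latent_tokens, text_tokens, dim = 2, 16, 8, 32

self_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
cross_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

hidden_states = torch.randn(batch, latent_tokens, dim)        # image/latent tokens
encoder_hidden_states = torch.randn(batch, text_tokens, dim)  # text-encoder tokens

# Mask over the latent sequence (self-attention only): shape (batch, latent_tokens).
attention_mask = torch.zeros(batch, latent_tokens, dtype=torch.bool)
# Mask over the text tokens (cross-attention only), e.g. padding: shape (batch, text_tokens).
encoder_attention_mask = torch.zeros(batch, text_tokens, dtype=torch.bool)
encoder_attention_mask[:, -2:] = True  # pretend the last two text tokens are padding

# 1. Self-attention: query/key/value are all latent tokens, so the latent mask applies.
h, _ = self_attn(hidden_states, hidden_states, hidden_states,
                 key_padding_mask=attention_mask)

# 2. Cross-attention: keys/values come from the text encoder, so the
#    key_padding_mask must be the *encoder* mask, not the latent mask.
h, _ = cross_attn(h, encoder_hidden_states, encoder_hidden_states,
                  key_padding_mask=encoder_attention_mask)

print(h.shape)  # torch.Size([2, 16, 32])

Passing the latent-shaped mask into the cross-attention call would either error out or silently mask the wrong keys, which is exactly what the commit's one-line swap to attention_mask=encoder_attention_mask avoids.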
tests/pipelines/stable_diffusion/test_stable_diffusion.py  (view file @ 4ae54b37)

...
@@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
...
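The only change in the slow test is the peak-memory floor, raised from 4e9 to 5e9 bytes; the assertion checks that the non-sliced run really does allocate a large amount of CUDA memory. For context, here is a hedged sketch of how such a check is typically wrapped around a generation call; check_peak_memory and run_pipeline are hypothetical helper names, not part of the test file, and only torch.cuda.max_memory_allocated() appears in the actual diff.

import torch

def check_peak_memory(run_pipeline, floor_bytes=5e9):
    # Hypothetical helper: measure peak CUDA allocation around a generation call.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # start peak tracking from zero
    run_pipeline()
    mem_bytes = torch.cuda.max_memory_allocated()
    # The test asserts the peak stays *above* a floor (> 5e9 after this commit),
    # so it can be contrasted with the lower-memory sliced-attention run.
    assert mem_bytes > floor_bytes, f"peak allocation {mem_bytes} not > {floor_bytes}"
    return mem_bytes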