Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
425192fe
Unverified
Commit
425192fe
authored
Apr 22, 2023
by
Patrick von Platen
Committed by
GitHub
Apr 22, 2023
Browse files
Make sure VAE attention works with Torch 2_0 (#3200)
* Make sure attention works with Torch 2_0 * make style * Fix more
parent
9965cb50
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
12 deletions
+63
-12
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+29
-12
tests/models/test_models_vae.py
tests/models/test_models_vae.py
+34
-0
No files found.
src/diffusers/models/attention.py
View file @
425192fe
...
...
@@ -60,7 +60,6 @@ class AttentionBlock(nn.Module):
self
.
channels
=
channels
self
.
num_heads
=
channels
//
num_head_channels
if
num_head_channels
is
not
None
else
1
self
.
num_head_size
=
num_head_channels
self
.
group_norm
=
nn
.
GroupNorm
(
num_channels
=
channels
,
num_groups
=
norm_num_groups
,
eps
=
eps
,
affine
=
True
)
# define q,k,v as linear layers
...
...
@@ -74,18 +73,25 @@ class AttentionBlock(nn.Module):
self
.
_use_memory_efficient_attention_xformers
=
False
self
.
_attention_op
=
None
def
reshape_heads_to_batch_dim
(
self
,
tensor
):
def
reshape_heads_to_batch_dim
(
self
,
tensor
,
merge_head_and_batch
=
True
):
batch_size
,
seq_len
,
dim
=
tensor
.
shape
head_size
=
self
.
num_heads
tensor
=
tensor
.
reshape
(
batch_size
,
seq_len
,
head_size
,
dim
//
head_size
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
*
head_size
,
seq_len
,
dim
//
head_size
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
)
if
merge_head_and_batch
:
tensor
=
tensor
.
reshape
(
batch_size
*
head_size
,
seq_len
,
dim
//
head_size
)
return
tensor
def
reshape_batch_dim_to_heads
(
self
,
tensor
):
batch_size
,
seq_len
,
dim
=
tensor
.
shape
def
reshape_batch_dim_to_heads
(
self
,
tensor
,
unmerge_head_and_batch
=
True
):
head_size
=
self
.
num_heads
if
unmerge_head_and_batch
:
batch_size
,
seq_len
,
dim
=
tensor
.
shape
tensor
=
tensor
.
reshape
(
batch_size
//
head_size
,
head_size
,
seq_len
,
dim
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
//
head_size
,
seq_len
,
dim
*
head_size
)
else
:
batch_size
,
_
,
seq_len
,
dim
=
tensor
.
shape
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
,
seq_len
,
dim
*
head_size
)
return
tensor
def
set_use_memory_efficient_attention_xformers
(
...
...
@@ -134,14 +140,25 @@ class AttentionBlock(nn.Module):
scale
=
1
/
math
.
sqrt
(
self
.
channels
/
self
.
num_heads
)
query_proj
=
self
.
reshape_heads_to_batch_dim
(
query_proj
)
key_proj
=
self
.
reshape_heads_to_batch_dim
(
key_proj
)
value_proj
=
self
.
reshape_heads_to_batch_dim
(
value_proj
)
use_torch_2_0_attn
=
(
hasattr
(
F
,
"scaled_dot_product_attention"
)
and
not
self
.
_use_memory_efficient_attention_xformers
)
query_proj
=
self
.
reshape_heads_to_batch_dim
(
query_proj
,
merge_head_and_batch
=
not
use_torch_2_0_attn
)
key_proj
=
self
.
reshape_heads_to_batch_dim
(
key_proj
,
merge_head_and_batch
=
not
use_torch_2_0_attn
)
value_proj
=
self
.
reshape_heads_to_batch_dim
(
value_proj
,
merge_head_and_batch
=
not
use_torch_2_0_attn
)
if
self
.
_use_memory_efficient_attention_xformers
:
# Memory efficient attention
hidden_states
=
xformers
.
ops
.
memory_efficient_attention
(
query_proj
,
key_proj
,
value_proj
,
attn_bias
=
None
,
op
=
self
.
_attention_op
query_proj
,
key_proj
,
value_proj
,
attn_bias
=
None
,
op
=
self
.
_attention_op
,
scale
=
scale
)
hidden_states
=
hidden_states
.
to
(
query_proj
.
dtype
)
elif
use_torch_2_0_attn
:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states
=
F
.
scaled_dot_product_attention
(
query_proj
,
key_proj
,
value_proj
,
dropout_p
=
0.0
,
is_causal
=
False
)
hidden_states
=
hidden_states
.
to
(
query_proj
.
dtype
)
else
:
...
...
@@ -162,7 +179,7 @@ class AttentionBlock(nn.Module):
hidden_states
=
torch
.
bmm
(
attention_probs
,
value_proj
)
# reshape hidden_states
hidden_states
=
self
.
reshape_batch_dim_to_heads
(
hidden_states
)
hidden_states
=
self
.
reshape_batch_dim_to_heads
(
hidden_states
,
unmerge_head_and_batch
=
not
use_torch_2_0_attn
)
# compute next hidden_states
hidden_states
=
self
.
proj_attn
(
hidden_states
)
...
...
tests/models/test_models_vae.py
View file @
425192fe
...
...
@@ -319,6 +319,40 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
assert
torch_all_close
(
output_slice
,
expected_output_slice
,
atol
=
5e-3
)
@
parameterized
.
expand
([
13
,
16
,
27
])
@
require_torch_gpu
def
test_stable_diffusion_decode_xformers_vs_2_0_fp16
(
self
,
seed
):
model
=
self
.
get_sd_vae_model
(
fp16
=
True
)
encoding
=
self
.
get_sd_image
(
seed
,
shape
=
(
3
,
4
,
64
,
64
),
fp16
=
True
)
with
torch
.
no_grad
():
sample
=
model
.
decode
(
encoding
).
sample
model
.
enable_xformers_memory_efficient_attention
()
with
torch
.
no_grad
():
sample_2
=
model
.
decode
(
encoding
).
sample
assert
list
(
sample
.
shape
)
==
[
3
,
3
,
512
,
512
]
assert
torch_all_close
(
sample
,
sample_2
,
atol
=
1e-1
)
@
parameterized
.
expand
([
13
,
16
,
37
])
@
require_torch_gpu
def
test_stable_diffusion_decode_xformers_vs_2_0
(
self
,
seed
):
model
=
self
.
get_sd_vae_model
()
encoding
=
self
.
get_sd_image
(
seed
,
shape
=
(
3
,
4
,
64
,
64
))
with
torch
.
no_grad
():
sample
=
model
.
decode
(
encoding
).
sample
model
.
enable_xformers_memory_efficient_attention
()
with
torch
.
no_grad
():
sample_2
=
model
.
decode
(
encoding
).
sample
assert
list
(
sample
.
shape
)
==
[
3
,
3
,
512
,
512
]
assert
torch_all_close
(
sample
,
sample_2
,
atol
=
1e-2
)
@
parameterized
.
expand
(
[
# fmt: off
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment