OpenDAS / diffusers · commit ea39cd7e (unverified)

Attn added kv processor torch 2.0 block (#3023)

add AttnAddedKVProcessor2_0 block

Authored by Will Berman on Apr 11, 2023; committed via GitHub on Apr 11, 2023.
Parent: 98c5e5da
4 changed files, with 109 additions and 12 deletions:

- src/diffusers/models/attention_processor.py (+73, -5)
- src/diffusers/models/unet_2d_blocks.py (+19, -4)
- src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (+11, -2)
- tests/pipelines/unclip/test_unclip_image_variation.py (+6, -1)
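Beyond the automatic dispatch inside the UNet blocks (shown in the diffs below), the new processor can also be opted into manually. A minimal sketch, not part of the commit, assuming an already-loaded `unet` whose attention layers use the added-KV processors (e.g. an unCLIP-style UNet) and the existing `set_attn_processor` API:

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
)

# Prefer the fused PyTorch 2.0 kernel when it is available, mirroring the
# dispatch this commit adds inside the UNet blocks.
processor = (
    AttnAddedKVProcessor2_0()
    if hasattr(F, "scaled_dot_product_attention")
    else AttnAddedKVProcessor()
)

# `unet` is a hypothetical, already-loaded model; `set_attn_processor`
# swaps the processor on every attention module in the model.
unet.set_attn_processor(processor)
```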
src/diffusers/models/attention_processor.py
```diff
@@ -255,11 +255,15 @@ class Attention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor
 
     def get_attention_scores(self, query, key, attention_mask=None):
```
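The new `out_dim` flag lets callers choose between the old 3-D layout, with heads folded into the batch axis, and a 4-D layout that keeps a separate head axis, which is the shape `F.scaled_dot_product_attention` expects. A standalone shape check, not from the commit, with hypothetical sizes:

```python
import torch

heads = 8
batch, seq, inner = 2, 16, 512  # inner = heads * head_dim
x = torch.randn(batch, seq, inner)

# Split the inner dim into heads and move the head axis forward.
t = x.reshape(batch, seq, heads, inner // heads).permute(0, 2, 1, 3)

# out_dim=4 (new): keep the head axis -> (batch, heads, seq, head_dim)
assert t.shape == (2, 8, 16, 64)

# out_dim=3 (default, old behavior): fold heads into the batch axis
# -> (batch * heads, seq, head_dim)
assert t.reshape(batch * heads, seq, inner // heads).shape == (16, 16, 64)
```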
```diff
@@ -293,7 +297,7 @@ class Attention(nn.Module):
         return attention_probs
 
-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
```
```diff
@@ -320,8 +324,13 @@ class Attention(nn.Module):
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
 
-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
 
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
```
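`prepare_attention_mask` gains the same flag: with `out_dim=3` the mask is repeated along the batch axis as before, while with `out_dim=4` it gets an explicit head axis so it broadcasts against SDPA's `(batch, heads, query_len, key_len)` attention weights. A standalone sketch of the two layouts, not from the commit:

```python
import torch

batch, heads, target_len = 2, 8, 16
mask = torch.zeros(batch, 1, target_len)  # one mask row per sample

# out_dim=3: repeat along dim 0 -> (batch * heads, 1, target_len), as before
m3 = mask.repeat_interleave(heads, dim=0)
assert m3.shape == (16, 1, 16)

# out_dim=4: add a head axis, then repeat along it
# -> (batch, heads, 1, target_len)
m4 = mask.unsqueeze(1).repeat_interleave(heads, dim=1)
assert m4.shape == (2, 8, 1, 16)
```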
```diff
@@ -499,6 +508,64 @@ class AttnAddedKVProcessor:
         return hidden_states
 
 
+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
```
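With `dropout_p=0.0` and `is_causal=False`, the fused kernel computes plain scaled dot-product attention, `softmax(Q @ K^T / sqrt(head_dim)) @ V`. PyTorch 2.0 hard-codes the `1 / sqrt(head_dim)` scale, which is why the TODO about `attn.scale` waits for the `scale` argument that arrives in Torch 2.1. A small reference check, a sketch rather than part of the commit:

```python
import math

import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim), the 4-D layout produced above
q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))

# Unfused reference for what the kernel computes with
# dropout_p=0.0 and is_causal=False.
scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
ref = scores.softmax(dim=-1) @ v

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```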
```diff
@@ -764,6 +831,7 @@ AttentionProcessor = Union[
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
```
src/diffusers/models/unet_2d_blocks.py
```diff
@@ -15,10 +15,11 @@ from typing import Optional
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from .attention import AdaGroupNorm, AttentionBlock
-from .attention_processor import Attention, AttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
```
```diff
@@ -612,6 +613,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attentions = []
 
         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
```
```diff
@@ -624,7 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
```
```diff
@@ -1396,6 +1401,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
```
```diff
@@ -1408,7 +1418,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
```
```diff
@@ -2399,6 +2409,11 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
```
```diff
@@ -2411,7 +2426,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
```
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
```diff
@@ -8,7 +8,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
+from ...models.attention_processor import (
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    AttnProcessor,
+)
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
```
```diff
@@ -1545,6 +1550,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attentions = []
 
         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
```
```diff
@@ -1557,7 +1566,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
```
tests/pipelines/unclip/test_unclip_image_variation.py
```diff
@@ -421,7 +421,12 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
 
-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )
 
     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
```