renzhc / diffusers_dcu · Commit ea39cd7e (unverified)

Attn added kv processor torch 2.0 block (#3023)

add AttnAddedKVProcessor2_0 block

Authored by Will Berman on Apr 11, 2023; committed via GitHub on Apr 11, 2023.
Parent: 98c5e5da
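For orientation (this note is not part of the commit): the new processor routes the added-KV attention path through `F.scaled_dot_product_attention`, the fused attention kernel introduced in PyTorch 2.0, instead of materializing the attention-probability matrix by hand. A minimal sketch of the equivalence it relies on, with invented tensor sizes:

```python
import torch
import torch.nn.functional as F

# Invented sizes, for illustration only.
batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Pre-2.0 style: materialize softmax(Q K^T / sqrt(d)) before multiplying by V.
scores = (q @ k.transpose(-1, -2)) / head_dim**0.5
manual = scores.softmax(dim=-1) @ v

# PyTorch 2.0 fused kernel: same math, without the (seq_len x seq_len) buffer.
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))  # True (up to kernel rounding)
```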
Showing 4 changed files with 109 additions and 12 deletions:

- src/diffusers/models/attention_processor.py (+73, -5)
- src/diffusers/models/unet_2d_blocks.py (+19, -4)
- src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (+11, -2)
- tests/pipelines/unclip/test_unclip_image_variation.py (+6, -1)
src/diffusers/models/attention_processor.py
```diff
@@ -255,11 +255,15 @@ class Attention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor
 
     def get_attention_scores(self, query, key, attention_mask=None):
```
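A side note on the `out_dim` switch above (sizes below are invented): `F.scaled_dot_product_attention` expects a 4-D `(batch, heads, seq_len, head_dim)` layout, whereas the pre-2.0 processors fold the heads into the batch dimension, so `head_to_batch_dim` now supports both:

```python
import torch

# Invented sizes: 2 samples, 8 heads, dim 512 -> head_dim 64.
batch_size, seq_len, dim, head_size = 2, 16, 512, 8

tensor = torch.randn(batch_size, seq_len, dim)
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)

print(tensor.shape)  # torch.Size([2, 8, 16, 64]) -- out_dim=4, fed to SDPA
print(tensor.reshape(batch_size * head_size, seq_len, dim // head_size).shape)
# torch.Size([16, 16, 64])                        -- out_dim=3, legacy path
```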
```diff
@@ -293,7 +297,7 @@ class Attention(nn.Module):
         return attention_probs
 
-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
```
```diff
@@ -320,8 +324,13 @@ class Attention(nn.Module):
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
 
-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
```
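The mask handling mirrors this: with `out_dim=4` the mask gains a heads axis so it can broadcast against 4-D attention scores, rather than being tiled along the batch. A small sketch (shapes invented):

```python
import torch

batch_size, head_size, target_length = 2, 8, 16
attention_mask = torch.zeros(batch_size, 1, target_length)

# out_dim == 3: heads folded into the batch dimension (legacy layout).
mask3 = attention_mask.repeat_interleave(head_size, dim=0)
print(mask3.shape)  # torch.Size([16, 1, 16])

# out_dim == 4: explicit heads axis, matching (batch, heads, q_len, kv_len).
mask4 = attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)
print(mask4.shape)  # torch.Size([2, 8, 1, 16])
```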
```diff
@@ -499,6 +508,64 @@ class AttnAddedKVProcessor:
         return hidden_states
 
 
+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
```
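The UNet blocks changed below select this processor automatically at construction time, but a loaded model can also be opted in by hand. A sketch of how that might look; the helper name is ours, and `unet` is assumed to be a diffusers UNet whose blocks use the added-KV processors (e.g. the UnCLIP decoder):

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
)


def use_fused_added_kv_attention(unet) -> None:
    """Hypothetical helper: pick the fused processor on PyTorch >= 2.0."""
    processor = (
        AttnAddedKVProcessor2_0()
        if hasattr(F, "scaled_dot_product_attention")  # only exists on 2.0+
        else AttnAddedKVProcessor()
    )
    # set_attn_processor accepts a single processor applied to every block.
    unet.set_attn_processor(processor)
```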
```diff
@@ -764,6 +831,7 @@ AttentionProcessor = Union[
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
```
src/diffusers/models/unet_2d_blocks.py
```diff
@@ -15,10 +15,11 @@ from typing import Optional
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from .attention import AdaGroupNorm, AttentionBlock
-from .attention_processor import Attention, AttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
```
```diff
@@ -612,6 +613,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attentions = []
 
         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
```
```diff
@@ -624,7 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
```
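Note the selection idiom used in each of these blocks: `hasattr(F, "scaled_dot_product_attention")` is a capability probe rather than a version-string check, so PyTorch builds older than 2.0 silently keep the original processor. The same guard in isolation (the manual fallback here is a stand-in for illustration, not diffusers code):

```python
import torch
import torch.nn.functional as F


def attention(q, k, v):
    """Use the fused kernel when this PyTorch build provides it."""
    if hasattr(F, "scaled_dot_product_attention"):  # PyTorch >= 2.0
        return F.scaled_dot_product_attention(q, k, v)
    # Stand-in for the pre-2.0 path: explicit softmax(Q K^T / sqrt(d)) V.
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v


q = k = v = torch.randn(2, 4, 16, 8)
print(attention(q, k, v).shape)  # torch.Size([2, 4, 16, 8])
```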
```diff
@@ -1396,6 +1401,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
```
```diff
@@ -1408,7 +1418,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
```
```diff
@@ -2399,6 +2409,11 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
```
```diff
@@ -2411,7 +2426,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
```
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
```diff
@@ -8,7 +8,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
+from ...models.attention_processor import (
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    AttnProcessor,
+)
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
```
```diff
@@ -1545,6 +1550,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attentions = []
 
         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
```
```diff
@@ -1557,7 +1566,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
```
tests/pipelines/unclip/test_unclip_image_variation.py
```diff
@@ -421,7 +421,12 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
 
-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )
 
     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
```