Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
2d9ccf39
Unverified
Commit
2d9ccf39
authored
Aug 23, 2024
by
Sayak Paul
Committed by
GitHub
Aug 23, 2024
Browse files
[Core] fuse_qkv_projection() to Flux (#9185)
* start fusing flux. * test * finish fusion * fix-copies
parent
960c149c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
245 additions
and
3 deletions
+245
-3
src/diffusers/models/attention_processor.py
src/diffusers/models/attention_processor.py
+94
-0
src/diffusers/models/transformers/transformer_flux.py
src/diffusers/models/transformers/transformer_flux.py
+106
-1
tests/pipelines/flux/test_pipeline_flux.py
tests/pipelines/flux/test_pipeline_flux.py
+45
-2
No files found.
src/diffusers/models/attention_processor.py
View file @
2d9ccf39
...
...
@@ -1783,6 +1783,100 @@ class FluxAttnProcessor2_0:
return
hidden_states
class FusedFluxAttnProcessor2_0:
    """Flux attention processor that works on fused QKV projections.

    Query, key and value are produced by a single call to `attn.to_qkv` (and, for the context stream, a single call
    to `attn.to_added_qkv`) and then split into three equal parts along the last dimension.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    @staticmethod
    def _split_qkv(fused):
        """Split a fused projection into three equally sized chunks along the last dimension."""
        return torch.split(fused, fused.shape[-1] // 3, dim=-1)

    @staticmethod
    def _to_heads(tensor, batch_size, heads, head_dim):
        """Reshape `(batch, seq, heads * head_dim)` into `(batch, heads, seq, head_dim)`."""
        return tensor.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run attention over `hidden_states`, optionally joined with `encoder_hidden_states`.

        Args:
            attn: Attention module providing `heads`, `to_qkv`, `norm_q`, `norm_k` and, when
                `encoder_hidden_states` is given, `to_added_qkv`, `norm_added_q`, `norm_added_k`, `to_out` and
                `to_add_out`.
            hidden_states: 3-D `sample` stream tensor.
            encoder_hidden_states: Optional 3-D `context` stream tensor. When given, its projections are
                concatenated in front of the sample projections along the sequence dimension.
            attention_mask: Currently unused by this processor.
            image_rotary_emb: Optional rotary embedding applied to query and key before attention.

        Returns:
            If `encoder_hidden_states` is `None`, the attention output for `hidden_states` (no output projection
            is applied). Otherwise a tuple `(hidden_states, encoder_hidden_states)` where the sample part went
            through `attn.to_out` and the context part through `attn.to_add_out`.
        """
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query, key, value = self._split_qkv(attn.to_qkv(hidden_states))

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = self._to_heads(query, batch_size, attn.heads, head_dim)
        key = self._to_heads(key, batch_size, attn.heads, head_dim)
        value = self._to_heads(value, batch_size, attn.heads, head_dim)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        # `context` projections.
        if encoder_hidden_states is not None:
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = self._split_qkv(attn.to_added_qkv(encoder_hidden_states))

            # The context stream reuses the head size derived from the sample projections.
            encoder_hidden_states_query_proj = self._to_heads(
                encoder_hidden_states_query_proj, batch_size, attn.heads, head_dim
            )
            encoder_hidden_states_key_proj = self._to_heads(
                encoder_hidden_states_key_proj, batch_size, attn.heads, head_dim
            )
            encoder_hidden_states_value_proj = self._to_heads(
                encoder_hidden_states_value_proj, batch_size, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: context tokens first, sample tokens second (sequence dimension is 2 here)
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Undo the concatenation: the first `encoder_hidden_states.shape[1]` tokens belong to the context.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class
CogVideoXAttnProcessor2_0
:
r
"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
...
...
src/diffusers/models/transformers/transformer_flux.py
View file @
2d9ccf39
...
...
@@ -23,7 +23,12 @@ import torch.nn.functional as F
from
...configuration_utils
import
ConfigMixin
,
register_to_config
from
...loaders
import
FromOriginalModelMixin
,
PeftAdapterMixin
from
...models.attention
import
FeedForward
from
...models.attention_processor
import
Attention
,
FluxAttnProcessor2_0
from
...models.attention_processor
import
(
Attention
,
AttentionProcessor
,
FluxAttnProcessor2_0
,
FusedFluxAttnProcessor2_0
,
)
from
...models.modeling_utils
import
ModelMixin
from
...models.normalization
import
AdaLayerNormContinuous
,
AdaLayerNormZero
,
AdaLayerNormZeroSingle
from
...utils
import
USE_PEFT_BACKEND
,
is_torch_version
,
logging
,
scale_lora_layers
,
unscale_lora_layers
...
...
@@ -276,6 +281,106 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self
.
gradient_checkpointing
=
False
@property
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model, indexed
        by the dotted path of the owning module followed by `.processor`.
    """
    collected = {}

    # Depth-first, pre-order walk over the module tree. Children are pushed in reverse so that they are popped
    # (and therefore inserted into the dict) in their declaration order.
    pending = list(self.named_children())[::-1]
    while pending:
        path, module = pending.pop()
        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor()
        children = list(module.named_children())
        pending.extend((f"{path}.{child_name}", child) for child_name, child in reversed(children))

    return collected
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: "Union[AttentionProcessor, Dict[str, AttentionProcessor]]"):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict is only
            read; it is left unmodified.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of attention processors in the model.
        KeyError: If a dict is passed that has no `"<module path>.processor"` entry for some layer.
    """
    count = len(self.attn_processors)

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Read the entry rather than popping it: popping empties the caller's dict, which breaks callers
                # that keep the mapping around to re-apply it later.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The processors in use before fusing are kept in `self.original_attn_processors`. Every `Attention` submodule
    gets `fuse_projections(fuse=True)` called on it, and `FusedFluxAttnProcessor2_0` is then installed as the
    processor.

    Raises:
        ValueError: If the class name of any current attention processor contains "Added".

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    # Reset first, so a failed check below leaves no stale processors behind.
    self.original_attn_processors = None

    current_processors = self.attn_processors.values()
    if any("Added" in str(proc.__class__.__name__) for proc in current_processors):
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    self.original_attn_processors = self.attn_processors

    # Kept lazy on purpose: the module tree is walked while the layers are being fused.
    attention_layers = (layer for layer in self.modules() if isinstance(layer, Attention))
    for layer in attention_layers:
        layer.fuse_projections(fuse=True)

    self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the attention processors stored in `self.original_attn_processors`. Does nothing when no processors
    were stored, including when `fuse_qkv_projections()` was never called on this model. Safe to call repeatedly.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # The attribute only exists once `fuse_qkv_projections()` has run, so do not assume it is there.
    original = getattr(self, "original_attn_processors", None)
    if original is None:
        return

    # Hand over a shallow copy of a dict so the stored mapping survives even if `set_attn_processor` consumes
    # the dict it is given.
    self.set_attn_processor(dict(original) if isinstance(original, dict) else original)
def _set_gradient_checkpointing(self, module, value=False):
    """Set `module.gradient_checkpointing` to `value` if `module` already has that attribute."""
    supports_checkpointing = hasattr(module, "gradient_checkpointing")
    if supports_checkpointing:
        module.gradient_checkpointing = value
...
...
tests/pipelines/flux/test_pipeline_flux.py
View file @
2d9ccf39
...
...
@@ -13,10 +13,13 @@ from diffusers.utils.testing_utils import (
torch_device
,
)
from
..test_pipelines_common
import
PipelineTesterMixin
from
..test_pipelines_common
import
(
PipelineTesterMixin
,
check_qkv_fusion_matches_attn_procs_length
,
check_qkv_fusion_processors_exist
,
)
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"Flux has a float64 operation which is not supported in MPS."
)
class
FluxPipelineFastTests
(
unittest
.
TestCase
,
PipelineTesterMixin
):
pipeline_class
=
FluxPipeline
params
=
frozenset
([
"prompt"
,
"height"
,
"width"
,
"guidance_scale"
,
"prompt_embeds"
,
"pooled_prompt_embeds"
])
...
...
@@ -143,6 +146,46 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
max_diff
=
np
.
abs
(
output_with_prompt
-
output_with_embeds
).
max
()
assert
max_diff
<
1e-4
def test_fused_qkv_projections(self):
    """Outputs must agree before fusing, while fused, and after un-fusing the transformer's QKV projections."""
    device = "cpu"  # ensure determinism for the device-dependent torch.Generator
    pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
    pipe.set_progress_bar_config(disable=None)
    transformer = pipe.transformer

    def run_and_slice():
        # Every run builds its own dummy inputs, then keeps a 3x3 corner of the last channel of image 0.
        generated = pipe(**self.get_dummy_inputs(device)).images
        return generated[0, -3:, -3:, -1]

    original_image_slice = run_and_slice()

    # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
    # to the pipeline level.
    transformer.fuse_qkv_projections()
    assert check_qkv_fusion_processors_exist(
        transformer
    ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
    assert check_qkv_fusion_matches_attn_procs_length(
        transformer, transformer.original_attn_processors
    ), "Something wrong with the attention processors concerning the fused QKV projections."

    image_slice_fused = run_and_slice()

    transformer.unfuse_qkv_projections()
    image_slice_disabled = run_and_slice()

    assert np.allclose(
        original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
    ), "Fusion of QKV projections shouldn't affect the outputs."
    assert np.allclose(
        image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
    ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
    assert np.allclose(
        original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
    ), "Original outputs should match when fused QKV projections are disabled."
@
slow
@
require_torch_gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment