renzhc / diffusers_dcu · Commits · 2d9ccf39

Unverified commit 2d9ccf39, authored Aug 23, 2024 by Sayak Paul, committed via GitHub on Aug 23, 2024.

[Core] fuse_qkv_projection() to Flux (#9185)

* start fusing flux
* test
* finish fusion
* fix-copies
parent 960c149c

Showing 3 changed files with 245 additions and 3 deletions:

- src/diffusers/models/attention_processor.py (+94, −0)
- src/diffusers/models/transformers/transformer_flux.py (+106, −1)
- tests/pipelines/flux/test_pipeline_flux.py (+45, −2)
src/diffusers/models/attention_processor.py

@@ -1783,6 +1783,100 @@ class FluxAttnProcessor2_0:

```python
            return hidden_states

class FusedFluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = torch.split(encoder_qkv, split_size, dim=-1)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    ...
```
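The fused processor above computes Q, K, and V with a single `attn.to_qkv` projection and then splits the result, instead of running three separate projections. The weight fusion itself is performed by `Attention.fuse_projections(fuse=True)` (called from the model in the next file). A minimal, self-contained sketch of why this is equivalent, using plain `nn.Linear` layers rather than the actual `Attention` module (all names and sizes here are illustrative, not from the diff):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, batch, seq = 64, 2, 8

# Three separate projections, as in the unfused processor.
to_q = nn.Linear(dim, dim, bias=True)
to_k = nn.Linear(dim, dim, bias=True)
to_v = nn.Linear(dim, dim, bias=True)

# Fuse: concatenate weights and biases along the output dimension into one layer,
# which is conceptually what Attention.fuse_projections(fuse=True) does.
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(batch, seq, dim)

# Unfused path: three matmuls. Fused path: one matmul followed by a split,
# mirroring `qkv = attn.to_qkv(hidden_states)` and `torch.split` in the processor.
q_ref, k_ref, v_ref = to_q(x), to_k(x), to_v(x)
qkv = to_qkv(x)
split_size = qkv.shape[-1] // 3
q, k, v = torch.split(qkv, split_size, dim=-1)

assert torch.allclose(q, q_ref, atol=1e-5)
assert torch.allclose(k, k_ref, atol=1e-5)
assert torch.allclose(v, v_ref, atol=1e-5)
```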
src/diffusers/models/transformers/transformer_flux.py

@@ -23,7 +23,12 @@ import torch.nn.functional as F
```diff
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, FluxAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
```
@@ -276,6 +281,106 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig

```python
        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedFluxAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
```

...
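For orientation, the new model-level API is meant to be used as follows: inspect or swap the attention processors, fuse before inference, and optionally restore the originals afterwards. A hedged usage sketch; the tiny constructor arguments below are illustrative placeholders (not values taken from this repository), chosen only so the model can be built locally without a checkpoint:

```python
from diffusers import FluxTransformer2DModel

# Deliberately tiny, illustrative configuration, just large enough to exercise the processors.
transformer = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=[4, 4, 8],
)

# `attn_processors` maps module paths (e.g. "transformer_blocks.0.attn.processor")
# to processor instances; before fusion these are FluxAttnProcessor2_0.
for name, proc in transformer.attn_processors.items():
    print(name, type(proc).__name__)

# Fuse the q/k/v (and added q/k/v) projections of every Attention module and
# switch all layers to FusedFluxAttnProcessor2_0 via set_attn_processor().
transformer.fuse_qkv_projections()
assert all(type(p).__name__ == "FusedFluxAttnProcessor2_0" for p in transformer.attn_processors.values())

# The pre-fusion processors were stashed in `original_attn_processors`;
# unfuse_qkv_projections() puts them back through set_attn_processor().
transformer.unfuse_qkv_projections()
```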
tests/pipelines/flux/test_pipeline_flux.py

@@ -13,10 +13,13 @@ from diffusers.utils.testing_utils import (
```diff
     torch_device,
 )

-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)
```

```python
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = FluxPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
```
@@ -143,6 +146,46 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):

```python
        max_diff = np.abs(output_with_prompt - output_with_embeds).max()
        assert max_diff < 1e-4

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(
            pipe.transformer
        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        assert np.allclose(
            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
        ), "Fusion of QKV projections shouldn't affect the outputs."
        assert np.allclose(
            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        assert np.allclose(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

@slow
@require_torch_gpu
```

...
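Outside the test suite, the same pattern applies to a full pipeline: fuse on the transformer after loading, generate as usual, and optionally unfuse afterwards. A hedged end-to-end sketch; the checkpoint id, device, and sampling settings are assumptions for illustration and are not part of this commit:

```python
import torch
from diffusers import FluxPipeline

# Illustrative checkpoint and device; adjust to whatever is available locally.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Fuse the attention projections of the transformer (experimental API added in this commit).
pipe.transformer.fuse_qkv_projections()

image = pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux_fused.png")

# Outputs should match the unfused path within numerical tolerance,
# which is what test_fused_qkv_projections asserts above.
pipe.transformer.unfuse_qkv_projections()
```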