Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
14f7b545
Unverified
Commit
14f7b545
authored
Jun 05, 2024
by
Sayak Paul
Committed by
GitHub
Jun 05, 2024
Browse files
[Hunyuan DiT] feat: enable fusing qkv projections when doing attention (#8396)
* feat: introduce qkv fusion for Hunyuan * fix copies
parent
07cd2004
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
140 additions
and
2 deletions
+140
-2
src/diffusers/models/transformers/hunyuan_transformer_2d.py
src/diffusers/models/transformers/hunyuan_transformer_2d.py
+106
-2
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+34
-0
No files found.
src/diffusers/models/transformers/hunyuan_transformer_2d.py
View file @
14f7b545
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
Optional
from
typing
import
Dict
,
Optional
,
Union
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
...
@@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from
...utils
import
logging
from
...utils
import
logging
from
...utils.torch_utils
import
maybe_allow_in_graph
from
...utils.torch_utils
import
maybe_allow_in_graph
from
..attention
import
FeedForward
from
..attention
import
FeedForward
from
..attention_processor
import
Attention
,
HunyuanAttnProcessor2_0
from
..attention_processor
import
Attention
,
AttentionProcessor
,
HunyuanAttnProcessor2_0
from
..embeddings
import
(
from
..embeddings
import
(
HunyuanCombinedTimestepTextSizeStyleEmbedding
,
HunyuanCombinedTimestepTextSizeStyleEmbedding
,
PatchEmbed
,
PatchEmbed
,
...
@@ -321,6 +321,110 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
...
@@ -321,6 +321,110 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
self
.
norm_out
=
AdaLayerNormContinuous
(
self
.
inner_dim
,
self
.
inner_dim
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
norm_out
=
AdaLayerNormContinuous
(
self
.
inner_dim
,
self
.
inner_dim
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
proj_out
=
nn
.
Linear
(
self
.
inner_dim
,
patch_size
*
patch_size
*
self
.
out_channels
,
bias
=
True
)
self
.
proj_out
=
nn
.
Linear
(
self
.
inner_dim
,
patch_size
*
patch_size
*
self
.
out_channels
,
bias
=
True
)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def
fuse_qkv_projections
(
self
):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self
.
original_attn_processors
=
None
for
_
,
attn_processor
in
self
.
attn_processors
.
items
():
if
"Added"
in
str
(
attn_processor
.
__class__
.
__name__
):
raise
ValueError
(
"`fuse_qkv_projections()` is not supported for models having added KV projections."
)
self
.
original_attn_processors
=
self
.
attn_processors
for
module
in
self
.
modules
():
if
isinstance
(
module
,
Attention
):
module
.
fuse_projections
(
fuse
=
True
)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def
unfuse_qkv_projections
(
self
):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if
self
.
original_attn_processors
is
not
None
:
self
.
set_attn_processor
(
self
.
original_attn_processors
)
@
property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def
attn_processors
(
self
)
->
Dict
[
str
,
AttentionProcessor
]:
r
"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors
=
{}
def
fn_recursive_add_processors
(
name
:
str
,
module
:
torch
.
nn
.
Module
,
processors
:
Dict
[
str
,
AttentionProcessor
]):
if
hasattr
(
module
,
"get_processor"
):
processors
[
f
"
{
name
}
.processor"
]
=
module
.
get_processor
(
return_deprecated_lora
=
True
)
for
sub_name
,
child
in
module
.
named_children
():
fn_recursive_add_processors
(
f
"
{
name
}
.
{
sub_name
}
"
,
child
,
processors
)
return
processors
for
name
,
module
in
self
.
named_children
():
fn_recursive_add_processors
(
name
,
module
,
processors
)
return
processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def
set_attn_processor
(
self
,
processor
:
Union
[
AttentionProcessor
,
Dict
[
str
,
AttentionProcessor
]]):
r
"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count
=
len
(
self
.
attn_processors
.
keys
())
if
isinstance
(
processor
,
dict
)
and
len
(
processor
)
!=
count
:
raise
ValueError
(
f
"A dict of processors was passed, but the number of processors
{
len
(
processor
)
}
does not match the"
f
" number of attention layers:
{
count
}
. Please make sure to pass
{
count
}
processor classes."
)
def
fn_recursive_attn_processor
(
name
:
str
,
module
:
torch
.
nn
.
Module
,
processor
):
if
hasattr
(
module
,
"set_processor"
):
if
not
isinstance
(
processor
,
dict
):
module
.
set_processor
(
processor
)
else
:
module
.
set_processor
(
processor
.
pop
(
f
"
{
name
}
.processor"
))
for
sub_name
,
child
in
module
.
named_children
():
fn_recursive_attn_processor
(
f
"
{
name
}
.
{
sub_name
}
"
,
child
,
processor
)
for
name
,
module
in
self
.
named_children
():
fn_recursive_attn_processor
(
name
,
module
,
processor
)
def
set_default_attn_processor
(
self
):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self
.
set_attn_processor
(
HunyuanAttnProcessor2_0
())
def
forward
(
def
forward
(
self
,
self
,
hidden_states
,
hidden_states
,
...
...
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
View file @
14f7b545
...
@@ -228,6 +228,40 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -228,6 +228,40 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff
=
np
.
abs
(
to_np
(
output
)
-
to_np
(
output_loaded
)).
max
()
max_diff
=
np
.
abs
(
to_np
(
output
)
-
to_np
(
output_loaded
)).
max
()
self
.
assertLess
(
max_diff
,
1e-4
)
self
.
assertLess
(
max_diff
,
1e-4
)
def
test_fused_qkv_projections
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
components
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
=
pipe
.
to
(
device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
[
"return_dict"
]
=
False
image
=
pipe
(
**
inputs
)[
0
]
original_image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
pipe
.
transformer
.
fuse_qkv_projections
()
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
[
"return_dict"
]
=
False
image_fused
=
pipe
(
**
inputs
)[
0
]
image_slice_fused
=
image_fused
[
0
,
-
3
:,
-
3
:,
-
1
]
pipe
.
transformer
.
unfuse_qkv_projections
()
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
[
"return_dict"
]
=
False
image_disabled
=
pipe
(
**
inputs
)[
0
]
image_slice_disabled
=
image_disabled
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
np
.
allclose
(
original_image_slice
,
image_slice_fused
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Fusion of QKV projections shouldn't affect the outputs."
assert
np
.
allclose
(
image_slice_fused
,
image_slice_disabled
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert
np
.
allclose
(
original_image_slice
,
image_slice_disabled
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Original outputs should match when fused QKV projections are disabled."
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment