renzhc / diffusers_dcu · Commit 14f7b545

Unverified commit 14f7b545, authored Jun 05, 2024 by Sayak Paul and committed by GitHub on Jun 05, 2024.

[Hunyuan DiT] feat: enable fusing qkv projections when doing attention (#8396)

* feat: introduce qkv fusion for Hunyuan
* fix copies

Parent: 07cd2004
Showing 2 changed files with 140 additions and 2 deletions:

- src/diffusers/models/transformers/hunyuan_transformer_2d.py (+106, -2)
- tests/pipelines/hunyuan_dit/test_hunyuan_dit.py (+34, -0)
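For context before the diffs, here is a minimal usage sketch of the new API. The checkpoint id follows the Hub naming used in the diffusers docs at the time and is shown for illustration; substitute whichever HunyuanDiT checkpoint you actually use.

```python
import torch
from diffusers import HunyuanDiTPipeline

# Illustrative checkpoint id, not part of this commit.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")

# Fuse the query/key/value projection matrices into a single matmul per
# self-attention block (key/value only for cross-attention) before inference.
pipe.transformer.fuse_qkv_projections()
image = pipe(prompt="An astronaut riding a horse").images[0]

# Restore the original, unfused attention processors.
pipe.transformer.unfuse_qkv_projections()
```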
src/diffusers/models/transformers/hunyuan_transformer_2d.py (view file @ 14f7b545)

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -321,6 +321,110 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key,
+        value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(HunyuanAttnProcessor2_0())
+
     def forward(
         self,
         hidden_states,
```
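The mechanism itself lives in `Attention.fuse_projections(fuse=True)`, which this diff calls but does not show. Conceptually, for self-attention the query, key, and value weight matrices are concatenated along the output dimension so that one larger matmul replaces three smaller ones. A minimal standalone sketch under that assumption (illustrative names, not the exact diffusers internals):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 64

# Separate projections, as in the unfused attention block.
to_q = nn.Linear(dim, dim)
to_k = nn.Linear(dim, dim)
to_v = nn.Linear(dim, dim)

# Fused projection: concatenate the three weight matrices (and biases)
# along the output axis.
to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(2, 16, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)

# The fused path reproduces the unfused outputs up to floating-point error.
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```

Because the fused weight is an exact concatenation, outputs should match the unfused path up to numerical noise, which is exactly the invariant the new test below asserts.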
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py (view file @ 14f7b545)

```diff
@@ -228,6 +228,40 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
         self.assertLess(max_diff, 1e-4)
 
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image = pipe(**inputs)[0]
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        pipe.transformer.fuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image_fused = pipe(**inputs)[0]
+        image_slice_fused = image_fused[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image_disabled = pipe(**inputs)[0]
+        image_slice_disabled = image_disabled[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+
 
 @slow
 @require_torch_gpu
```
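The round-trip this test exercises can also be viewed at the processor level, tying together the new `attn_processors` property, `fuse_qkv_projections`, and `unfuse_qkv_projections`. A hedged sketch, where `model` stands for any `HunyuanDiT2DModel` instance (the function itself is illustrative, not part of this commit):

```python
from diffusers.models.attention_processor import HunyuanAttnProcessor2_0

def check_processor_roundtrip(model):
    # Snapshot the processors; keys are weight-path names along the lines of
    # "blocks.0.attn1.processor".
    original = model.attn_processors
    print(f"{len(original)} attention processors found")

    model.fuse_qkv_projections()    # stores the snapshot internally
    model.unfuse_qkv_projections()  # restores it via set_attn_processor

    # Alternatively, reset every Attention layer to the default processor.
    model.set_default_attn_processor()
    assert all(
        isinstance(p, HunyuanAttnProcessor2_0)
        for p in model.attn_processors.values()
    )
```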