Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a8ad6664
Unverified
Commit
a8ad6664
authored
Jun 05, 2024
by
Sayak Paul
Committed by
GitHub
Jun 05, 2024
Browse files
[Hunyuan] feat: support chunked ff. (#8397)
feat: support chunked ff.
parent
14f7b545
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
0 deletions
+63
-0
src/diffusers/models/transformers/hunyuan_transformer_2d.py
src/diffusers/models/transformers/hunyuan_transformer_2d.py
+43
-0
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+20
-0
No files found.
src/diffusers/models/transformers/hunyuan_transformer_2d.py
View file @
a8ad6664
...
...
@@ -166,6 +166,7 @@ class HunyuanDiTBlock(nn.Module):
self
.
_chunk_size
=
None
self
.
_chunk_dim
=
0
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def
set_chunk_feed_forward
(
self
,
chunk_size
:
Optional
[
int
],
dim
:
int
=
0
):
# Sets chunk feed-forward
self
.
_chunk_size
=
chunk_size
...
...
@@ -529,3 +530,45 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
if
not
return_dict
:
return
(
output
,)
return
Transformer2DModelOutput
(
sample
=
output
)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Enables [feed-forward
    chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on every
    submodule that supports it.

    Parameters:
        chunk_size (`int`, *optional*):
            The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
            over each tensor of dim=`dim`.
        dim (`int`, *optional*, defaults to `0`):
            The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
            or dim=1 (sequence length).

    Raises:
        ValueError: If `dim` is neither 0 nor 1.
    """
    if dim not in [0, 1]:
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    # A missing (or zero/None) chunk size falls back to chunking with size 1.
    effective_chunk_size = chunk_size or 1

    def _apply_chunking(submodule: torch.nn.Module, size: int, axis: int):
        # Configure this submodule if it exposes the chunked feed-forward hook ...
        if hasattr(submodule, "set_chunk_feed_forward"):
            submodule.set_chunk_feed_forward(chunk_size=size, dim=axis)
        # ... then recurse into its children so the whole subtree is covered.
        for grandchild in submodule.children():
            _apply_chunking(grandchild, size, axis)

    for child in self.children():
        _apply_chunking(child, effective_chunk_size, dim)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
    """Disables feed-forward chunking on every submodule that supports it."""

    def _reset_chunking(submodule: torch.nn.Module, size: int, axis: int):
        # Reset this submodule if it exposes the chunked feed-forward hook ...
        if hasattr(submodule, "set_chunk_feed_forward"):
            submodule.set_chunk_feed_forward(chunk_size=size, dim=axis)
        # ... then recurse into its children so the whole subtree is covered.
        for grandchild in submodule.children():
            _reset_chunking(grandchild, size, axis)

    # chunk_size=None restores the default (unchunked) feed-forward path.
    for child in self.children():
        _reset_chunking(child, None, 0)
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
View file @
a8ad6664
...
...
@@ -228,6 +228,26 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff
=
np
.
abs
(
to_np
(
output
)
-
to_np
(
output_loaded
)).
max
()
self
.
assertLess
(
max_diff
,
1e-4
)
def test_feed_forward_chunking(self):
    """Enabling feed-forward chunking must not change the pipeline output (within tolerance)."""
    device = "cpu"
    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    def _corner_slice():
        # Run the pipeline on fresh dummy inputs and return the bottom-right 3x3
        # patch of the last channel of the first image.
        outputs = pipe(**self.get_dummy_inputs(device)).images
        return outputs[0, -3:, -3:, -1]

    slice_without_chunking = _corner_slice()
    pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
    slice_with_chunking = _corner_slice()

    max_diff = np.abs(to_np(slice_without_chunking) - to_np(slice_with_chunking)).max()
    self.assertLess(max_diff, 1e-4)
def
test_fused_qkv_projections
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
components
=
self
.
get_dummy_components
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment