Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
a8ad6664
Unverified
Commit
a8ad6664
authored
Jun 05, 2024
by
Sayak Paul
Committed by
GitHub
Jun 05, 2024
Browse files
[Hunyuan] feat: support chunked ff. (#8397)
feat: support chunked ff.
parent
14f7b545
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
0 deletions
+63
-0
src/diffusers/models/transformers/hunyuan_transformer_2d.py
src/diffusers/models/transformers/hunyuan_transformer_2d.py
+43
-0
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+20
-0
No files found.
src/diffusers/models/transformers/hunyuan_transformer_2d.py
View file @
a8ad6664
...
...
@@ -166,6 +166,7 @@ class HunyuanDiTBlock(nn.Module):
self
.
_chunk_size
=
None
self
.
_chunk_dim
=
0
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def
set_chunk_feed_forward
(
self
,
chunk_size
:
Optional
[
int
],
dim
:
int
=
0
):
# Sets chunk feed-forward
self
.
_chunk_size
=
chunk_size
...
...
@@ -529,3 +530,45 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
if
not
return_dict
:
return
(
output
,)
return
Transformer2DModelOutput
(
sample
=
output
)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def
enable_forward_chunking
(
self
,
chunk_size
:
Optional
[
int
]
=
None
,
dim
:
int
=
0
)
->
None
:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if
dim
not
in
[
0
,
1
]:
raise
ValueError
(
f
"Make sure to set `dim` to either 0 or 1, not
{
dim
}
"
)
# By default chunk size is 1
chunk_size
=
chunk_size
or
1
def
fn_recursive_feed_forward
(
module
:
torch
.
nn
.
Module
,
chunk_size
:
int
,
dim
:
int
):
if
hasattr
(
module
,
"set_chunk_feed_forward"
):
module
.
set_chunk_feed_forward
(
chunk_size
=
chunk_size
,
dim
=
dim
)
for
child
in
module
.
children
():
fn_recursive_feed_forward
(
child
,
chunk_size
,
dim
)
for
module
in
self
.
children
():
fn_recursive_feed_forward
(
module
,
chunk_size
,
dim
)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def
disable_forward_chunking
(
self
):
def
fn_recursive_feed_forward
(
module
:
torch
.
nn
.
Module
,
chunk_size
:
int
,
dim
:
int
):
if
hasattr
(
module
,
"set_chunk_feed_forward"
):
module
.
set_chunk_feed_forward
(
chunk_size
=
chunk_size
,
dim
=
dim
)
for
child
in
module
.
children
():
fn_recursive_feed_forward
(
child
,
chunk_size
,
dim
)
for
module
in
self
.
children
():
fn_recursive_feed_forward
(
module
,
None
,
0
)
tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
View file @
a8ad6664
...
...
@@ -228,6 +228,26 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff
=
np
.
abs
(
to_np
(
output
)
-
to_np
(
output_loaded
)).
max
()
self
.
assertLess
(
max_diff
,
1e-4
)
def
test_feed_forward_chunking
(
self
):
device
=
"cpu"
components
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
to
(
device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
inputs
=
self
.
get_dummy_inputs
(
device
)
image
=
pipe
(
**
inputs
).
images
image_slice_no_chunking
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
pipe
.
transformer
.
enable_forward_chunking
(
chunk_size
=
1
,
dim
=
0
)
inputs
=
self
.
get_dummy_inputs
(
device
)
image
=
pipe
(
**
inputs
).
images
image_slice_chunking
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
max_diff
=
np
.
abs
(
to_np
(
image_slice_no_chunking
)
-
to_np
(
image_slice_chunking
)).
max
()
self
.
assertLess
(
max_diff
,
1e-4
)
def
test_fused_qkv_projections
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
components
=
self
.
get_dummy_components
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment