Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3be67060
Unverified
Commit
3be67060
authored
Mar 18, 2025
by
Aryan
Committed by
GitHub
Mar 18, 2025
Browse files
Fix Group offloading behaviour when using streams (#11097)
* update * update
parent
cb1b8b21
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
10 deletions
+17
-10
src/diffusers/hooks/group_offloading.py
src/diffusers/hooks/group_offloading.py
+17
-10
No files found.
src/diffusers/hooks/group_offloading.py
View file @
3be67060
...
@@ -181,6 +181,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
...
@@ -181,6 +181,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
self
.
_layer_execution_tracker_module_names
=
set
()
self
.
_layer_execution_tracker_module_names
=
set
()
def
initialize_hook
(
self
,
module
):
def
initialize_hook
(
self
,
module
):
def
make_execution_order_update_callback
(
current_name
,
current_submodule
):
def
callback
():
logger
.
debug
(
f
"Adding
{
current_name
}
to the execution order"
)
self
.
execution_order
.
append
((
current_name
,
current_submodule
))
return
callback
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass.
# layers are executed during the forward pass.
...
@@ -192,14 +199,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
...
@@ -192,14 +199,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
group_offloading_hook
=
registry
.
get_hook
(
_GROUP_OFFLOADING
)
group_offloading_hook
=
registry
.
get_hook
(
_GROUP_OFFLOADING
)
if
group_offloading_hook
is
not
None
:
if
group_offloading_hook
is
not
None
:
# For the first forward pass, we have to load in a blocking manner
def
make_execution_order_update_callback
(
current_name
,
current_submodule
):
group_offloading_hook
.
group
.
non_blocking
=
False
def
callback
():
logger
.
debug
(
f
"Adding
{
current_name
}
to the execution order"
)
self
.
execution_order
.
append
((
current_name
,
current_submodule
))
return
callback
layer_tracker_hook
=
LayerExecutionTrackerHook
(
make_execution_order_update_callback
(
name
,
submodule
))
layer_tracker_hook
=
LayerExecutionTrackerHook
(
make_execution_order_update_callback
(
name
,
submodule
))
registry
.
register_hook
(
layer_tracker_hook
,
_LAYER_EXECUTION_TRACKER
)
registry
.
register_hook
(
layer_tracker_hook
,
_LAYER_EXECUTION_TRACKER
)
self
.
_layer_execution_tracker_module_names
.
add
(
name
)
self
.
_layer_execution_tracker_module_names
.
add
(
name
)
...
@@ -229,6 +230,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
...
@@ -229,6 +230,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the layer execution tracker hooks from the submodules
# Remove the layer execution tracker hooks from the submodules
base_module_registry
=
module
.
_diffusers_hook
base_module_registry
=
module
.
_diffusers_hook
registries
=
[
submodule
.
_diffusers_hook
for
_
,
submodule
in
self
.
execution_order
]
registries
=
[
submodule
.
_diffusers_hook
for
_
,
submodule
in
self
.
execution_order
]
group_offloading_hooks
=
[
registry
.
get_hook
(
_GROUP_OFFLOADING
)
for
registry
in
registries
]
for
i
in
range
(
num_executed
):
for
i
in
range
(
num_executed
):
registries
[
i
].
remove_hook
(
_LAYER_EXECUTION_TRACKER
,
recurse
=
False
)
registries
[
i
].
remove_hook
(
_LAYER_EXECUTION_TRACKER
,
recurse
=
False
)
...
@@ -236,8 +238,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
...
@@ -236,8 +238,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry
.
remove_hook
(
_LAZY_PREFETCH_GROUP_OFFLOADING
,
recurse
=
False
)
base_module_registry
.
remove_hook
(
_LAZY_PREFETCH_GROUP_OFFLOADING
,
recurse
=
False
)
# Apply lazy prefetching by setting required attributes
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
group_offloading_hooks
=
[
registry
.
get_hook
(
_GROUP_OFFLOADING
)
for
registry
in
registries
]
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
# see the benefits of prefetching.
for
hook
in
group_offloading_hooks
:
hook
.
group
.
non_blocking
=
True
# Set required attributes for prefetching
if
num_executed
>
0
:
if
num_executed
>
0
:
base_module_group_offloading_hook
=
base_module_registry
.
get_hook
(
_GROUP_OFFLOADING
)
base_module_group_offloading_hook
=
base_module_registry
.
get_hook
(
_GROUP_OFFLOADING
)
base_module_group_offloading_hook
.
next_group
=
group_offloading_hooks
[
0
].
group
base_module_group_offloading_hook
.
next_group
=
group_offloading_hooks
[
0
].
group
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment