Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
813d42cc
Unverified
Commit
813d42cc
authored
Mar 18, 2025
by
Aryan
Committed by
GitHub
Mar 18, 2025
Browse files
Group offloading improvements (#11094)
update
parent
b4d7e9c6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
7 deletions
+23
-7
src/diffusers/hooks/group_offloading.py
src/diffusers/hooks/group_offloading.py
+23
-7
No files found.
src/diffusers/hooks/group_offloading.py
View file @
813d42cc
...
@@ -83,7 +83,10 @@ class ModuleGroup:
...
@@ -83,7 +83,10 @@ class ModuleGroup:
with
context
:
with
context
:
for
group_module
in
self
.
modules
:
for
group_module
in
self
.
modules
:
group_module
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
param
in
group_module
.
parameters
():
param
.
data
=
param
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
buffer
in
group_module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
if
self
.
parameters
is
not
None
:
if
self
.
parameters
is
not
None
:
for
param
in
self
.
parameters
:
for
param
in
self
.
parameters
:
param
.
data
=
param
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
param
.
data
=
param
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
...
@@ -98,6 +101,12 @@ class ModuleGroup:
...
@@ -98,6 +101,12 @@ class ModuleGroup:
for
group_module
in
self
.
modules
:
for
group_module
in
self
.
modules
:
for
param
in
group_module
.
parameters
():
for
param
in
group_module
.
parameters
():
param
.
data
=
self
.
cpu_param_dict
[
param
]
param
.
data
=
self
.
cpu_param_dict
[
param
]
if
self
.
parameters
is
not
None
:
for
param
in
self
.
parameters
:
param
.
data
=
self
.
cpu_param_dict
[
param
]
if
self
.
buffers
is
not
None
:
for
buffer
in
self
.
buffers
:
buffer
.
data
=
self
.
cpu_param_dict
[
buffer
]
else
:
else
:
for
group_module
in
self
.
modules
:
for
group_module
in
self
.
modules
:
group_module
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
group_module
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
...
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
...
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict
=
None
cpu_param_dict
=
None
if
stream
is
not
None
:
if
stream
is
not
None
:
for
param
in
module
.
parameters
():
cpu_param_dict
=
_get_pinned_cpu_param_dict
(
module
)
param
.
data
=
param
.
data
.
cpu
().
pin_memory
()
cpu_param_dict
=
{
param
:
param
.
data
for
param
in
module
.
parameters
()}
# Create module groups for ModuleList and Sequential blocks
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading
=
set
()
modules_with_group_offloading
=
set
()
...
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
...
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict
=
None
cpu_param_dict
=
None
if
stream
is
not
None
:
if
stream
is
not
None
:
for
param
in
module
.
parameters
():
cpu_param_dict
=
_get_pinned_cpu_param_dict
(
module
)
param
.
data
=
param
.
data
.
cpu
().
pin_memory
()
cpu_param_dict
=
{
param
:
param
.
data
for
param
in
module
.
parameters
()}
# Create module groups for leaf modules and apply group offloading hooks
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading
=
set
()
modules_with_group_offloading
=
set
()
...
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
...
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
registry
.
register_hook
(
lazy_prefetch_hook
,
_LAZY_PREFETCH_GROUP_OFFLOADING
)
registry
.
register_hook
(
lazy_prefetch_hook
,
_LAZY_PREFETCH_GROUP_OFFLOADING
)
def
_get_pinned_cpu_param_dict
(
module
:
torch
.
nn
.
Module
)
->
Dict
[
torch
.
nn
.
Parameter
,
torch
.
Tensor
]:
cpu_param_dict
=
{}
for
param
in
module
.
parameters
():
param
.
data
=
param
.
data
.
cpu
().
pin_memory
()
cpu_param_dict
[
param
]
=
param
.
data
for
buffer
in
module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
cpu
().
pin_memory
()
cpu_param_dict
[
buffer
]
=
buffer
.
data
return
cpu_param_dict
def
_gather_parameters_with_no_group_offloading_parent
(
def
_gather_parameters_with_no_group_offloading_parent
(
module
:
torch
.
nn
.
Module
,
modules_with_group_offloading
:
Set
[
str
]
module
:
torch
.
nn
.
Module
,
modules_with_group_offloading
:
Set
[
str
]
)
->
List
[
torch
.
nn
.
Parameter
]:
)
->
List
[
torch
.
nn
.
Parameter
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment