Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
2c1ed50f
Unverified
Commit
2c1ed50f
authored
Mar 20, 2025
by
Dhruv Nair
Committed by
GitHub
Mar 20, 2025
Browse files
Provide option to reduce CPU RAM usage in Group Offload (#11106)
* update * update * clean up
parent
15ad97f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
93 additions
and
55 deletions
+93
-55
src/diffusers/hooks/group_offloading.py
src/diffusers/hooks/group_offloading.py
+84
-54
src/diffusers/models/modeling_utils.py
src/diffusers/models/modeling_utils.py
+9
-1
No files found.
src/diffusers/hooks/group_offloading.py
View file @
2c1ed50f
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
contextlib
import
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
...
...
@@ -56,7 +56,7 @@ class ModuleGroup:
buffers
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
non_blocking
:
bool
=
False
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
cpu_param_dict
:
Optional
[
Dict
[
torch
.
nn
.
Parameter
,
torch
.
Tensor
]]
=
Non
e
,
low_cpu_mem_usage
=
Fals
e
,
onload_self
:
bool
=
True
,
)
->
None
:
self
.
modules
=
modules
...
...
@@ -64,15 +64,50 @@ class ModuleGroup:
self
.
onload_device
=
onload_device
self
.
offload_leader
=
offload_leader
self
.
onload_leader
=
onload_leader
self
.
parameters
=
parameters
self
.
buffers
=
buffers
self
.
parameters
=
parameters
or
[]
self
.
buffers
=
buffers
or
[]
self
.
non_blocking
=
non_blocking
or
stream
is
not
None
self
.
stream
=
stream
self
.
cpu_param_dict
=
cpu_param_dict
self
.
onload_self
=
onload_self
self
.
low_cpu_mem_usage
=
low_cpu_mem_usage
if
self
.
stream
is
not
None
and
self
.
cpu_param_dict
is
None
:
raise
ValueError
(
"cpu_param_dict must be provided when using stream for data transfer."
)
self
.
cpu_param_dict
=
self
.
_init_cpu_param_dict
()
def
_init_cpu_param_dict
(
self
):
cpu_param_dict
=
{}
if
self
.
stream
is
None
:
return
cpu_param_dict
for
module
in
self
.
modules
:
for
param
in
module
.
parameters
():
cpu_param_dict
[
param
]
=
param
.
data
.
cpu
()
if
self
.
low_cpu_mem_usage
else
param
.
data
.
cpu
().
pin_memory
()
for
buffer
in
module
.
buffers
():
cpu_param_dict
[
buffer
]
=
(
buffer
.
data
.
cpu
()
if
self
.
low_cpu_mem_usage
else
buffer
.
data
.
cpu
().
pin_memory
()
)
for
param
in
self
.
parameters
:
cpu_param_dict
[
param
]
=
param
.
data
.
cpu
()
if
self
.
low_cpu_mem_usage
else
param
.
data
.
cpu
().
pin_memory
()
for
buffer
in
self
.
buffers
:
cpu_param_dict
[
buffer
]
=
buffer
.
data
.
cpu
()
if
self
.
low_cpu_mem_usage
else
buffer
.
data
.
cpu
().
pin_memory
()
return
cpu_param_dict
@
contextmanager
def
_pinned_memory_tensors
(
self
):
pinned_dict
=
{}
try
:
for
param
,
tensor
in
self
.
cpu_param_dict
.
items
():
if
not
tensor
.
is_pinned
():
pinned_dict
[
param
]
=
tensor
.
pin_memory
()
else
:
pinned_dict
[
param
]
=
tensor
yield
pinned_dict
finally
:
pinned_dict
=
None
def
onload_
(
self
):
r
"""Onloads the group of modules to the onload_device."""
...
...
@@ -82,15 +117,30 @@ class ModuleGroup:
self
.
stream
.
synchronize
()
with
context
:
for
group_module
in
self
.
modules
:
for
param
in
group_module
.
parameters
():
param
.
data
=
param
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
buffer
in
group_module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
if
self
.
parameters
is
not
None
:
if
self
.
stream
is
not
None
:
with
self
.
_pinned_memory_tensors
()
as
pinned_memory
:
for
group_module
in
self
.
modules
:
for
param
in
group_module
.
parameters
():
param
.
data
=
pinned_memory
[
param
].
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
buffer
in
group_module
.
buffers
():
buffer
.
data
=
pinned_memory
[
buffer
].
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
param
in
self
.
parameters
:
param
.
data
=
pinned_memory
[
param
].
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
buffer
in
self
.
buffers
:
buffer
.
data
=
pinned_memory
[
buffer
].
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
else
:
for
group_module
in
self
.
modules
:
for
param
in
group_module
.
parameters
():
param
.
data
=
param
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
buffer
in
group_module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
for
param
in
self
.
parameters
:
param
.
data
=
param
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
if
self
.
buffers
is
not
None
:
for
buffer
in
self
.
buffers
:
buffer
.
data
=
buffer
.
data
.
to
(
self
.
onload_device
,
non_blocking
=
self
.
non_blocking
)
...
...
@@ -101,21 +151,18 @@ class ModuleGroup:
for
group_module
in
self
.
modules
:
for
param
in
group_module
.
parameters
():
param
.
data
=
self
.
cpu_param_dict
[
param
]
if
self
.
parameters
is
not
None
:
for
param
in
self
.
parameters
:
param
.
data
=
self
.
cpu_param_dict
[
param
]
if
self
.
buffers
is
not
None
:
for
buffer
in
self
.
buffers
:
buffer
.
data
=
self
.
cpu_param_dict
[
buffer
]
for
param
in
self
.
parameters
:
param
.
data
=
self
.
cpu_param_dict
[
param
]
for
buffer
in
self
.
buffers
:
buffer
.
data
=
self
.
cpu_param_dict
[
buffer
]
else
:
for
group_module
in
self
.
modules
:
group_module
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
if
self
.
parameters
is
not
None
:
for
param
in
self
.
parameters
:
param
.
data
=
param
.
data
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
if
self
.
buffers
is
not
None
:
for
buffer
in
self
.
buffers
:
buffer
.
data
=
buffer
.
data
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
for
param
in
self
.
parameters
:
param
.
data
=
param
.
data
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
for
buffer
in
self
.
buffers
:
buffer
.
data
=
buffer
.
data
.
to
(
self
.
offload_device
,
non_blocking
=
self
.
non_blocking
)
class
GroupOffloadingHook
(
ModelHook
):
...
...
@@ -284,6 +331,7 @@ def apply_group_offloading(
num_blocks_per_group
:
Optional
[
int
]
=
None
,
non_blocking
:
bool
=
False
,
use_stream
:
bool
=
False
,
low_cpu_mem_usage
=
False
,
)
->
None
:
r
"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
...
...
@@ -365,10 +413,12 @@ def apply_group_offloading(
raise
ValueError
(
"num_blocks_per_group must be provided when using offload_type='block_level'."
)
_apply_group_offloading_block_level
(
module
,
num_blocks_per_group
,
offload_device
,
onload_device
,
non_blocking
,
stream
module
,
num_blocks_per_group
,
offload_device
,
onload_device
,
non_blocking
,
stream
,
low_cpu_mem_usage
)
elif
offload_type
==
"leaf_level"
:
_apply_group_offloading_leaf_level
(
module
,
offload_device
,
onload_device
,
non_blocking
,
stream
)
_apply_group_offloading_leaf_level
(
module
,
offload_device
,
onload_device
,
non_blocking
,
stream
,
low_cpu_mem_usage
)
else
:
raise
ValueError
(
f
"Unsupported offload_type:
{
offload_type
}
"
)
...
...
@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
onload_device
:
torch
.
device
,
non_blocking
:
bool
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
low_cpu_mem_usage
:
bool
=
False
,
)
->
None
:
r
"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
...
...
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict
=
None
if
stream
is
not
None
:
cpu_param_dict
=
_get_pinned_cpu_param_dict
(
module
)
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading
=
set
()
unmatched_modules
=
[]
...
...
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
onload_leader
=
current_modules
[
0
],
non_blocking
=
non_blocking
,
stream
=
stream
,
cpu_param_dict
=
cpu_param_dict
,
low_cpu_mem_usage
=
low_cpu_mem_usage
,
onload_self
=
stream
is
None
,
)
matched_module_groups
.
append
(
group
)
...
...
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
buffers
=
buffers
,
non_blocking
=
False
,
stream
=
None
,
cpu_param_dict
=
None
,
onload_self
=
True
,
)
next_group
=
matched_module_groups
[
0
]
if
len
(
matched_module_groups
)
>
0
else
None
...
...
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
onload_device
:
torch
.
device
,
non_blocking
:
bool
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
low_cpu_mem_usage
:
bool
=
False
,
)
->
None
:
r
"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
...
...
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict
=
None
if
stream
is
not
None
:
cpu_param_dict
=
_get_pinned_cpu_param_dict
(
module
)
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading
=
set
()
for
name
,
submodule
in
module
.
named_modules
():
...
...
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
onload_leader
=
submodule
,
non_blocking
=
non_blocking
,
stream
=
stream
,
cpu_param_dict
=
cpu_param_dict
,
low_cpu_mem_usage
=
low_cpu_mem_usage
,
onload_self
=
True
,
)
_apply_group_offloading_hook
(
submodule
,
group
,
None
)
...
...
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
buffers
=
buffers
,
non_blocking
=
non_blocking
,
stream
=
stream
,
cpu_param_dict
=
cpu_param_dict
,
low_cpu_mem_usage
=
low_cpu_mem_usage
,
onload_self
=
True
,
)
_apply_group_offloading_hook
(
parent_module
,
group
,
None
)
...
...
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
buffers
=
None
,
non_blocking
=
False
,
stream
=
None
,
cpu_param_dict
=
Non
e
,
low_cpu_mem_usage
=
low_cpu_mem_usag
e
,
onload_self
=
True
,
)
_apply_lazy_group_offloading_hook
(
module
,
unmatched_group
,
None
)
...
...
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
registry
.
register_hook
(
lazy_prefetch_hook
,
_LAZY_PREFETCH_GROUP_OFFLOADING
)
def
_get_pinned_cpu_param_dict
(
module
:
torch
.
nn
.
Module
)
->
Dict
[
torch
.
nn
.
Parameter
,
torch
.
Tensor
]:
cpu_param_dict
=
{}
for
param
in
module
.
parameters
():
param
.
data
=
param
.
data
.
cpu
().
pin_memory
()
cpu_param_dict
[
param
]
=
param
.
data
for
buffer
in
module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
cpu
().
pin_memory
()
cpu_param_dict
[
buffer
]
=
buffer
.
data
return
cpu_param_dict
def
_gather_parameters_with_no_group_offloading_parent
(
module
:
torch
.
nn
.
Module
,
modules_with_group_offloading
:
Set
[
str
]
)
->
List
[
torch
.
nn
.
Parameter
]:
...
...
src/diffusers/models/modeling_utils.py
View file @
2c1ed50f
...
...
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group
:
Optional
[
int
]
=
None
,
non_blocking
:
bool
=
False
,
use_stream
:
bool
=
False
,
low_cpu_mem_usage
=
False
,
)
->
None
:
r
"""
Activates group offloading for the current model.
...
...
@@ -584,7 +585,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f
"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading
(
self
,
onload_device
,
offload_device
,
offload_type
,
num_blocks_per_group
,
non_blocking
,
use_stream
self
,
onload_device
,
offload_device
,
offload_type
,
num_blocks_per_group
,
non_blocking
,
use_stream
,
low_cpu_mem_usage
=
low_cpu_mem_usage
,
)
def
save_pretrained
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment