Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
6cef71de
Unverified
Commit
6cef71de
authored
Apr 23, 2025
by
Aryan
Committed by
GitHub
Apr 23, 2025
Browse files
Fix group offloading with block_level and use_stream=True (#11375)
* fix * add tests * add message check
parent
026507c0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
9 deletions
+64
-9
src/diffusers/hooks/group_offloading.py
src/diffusers/hooks/group_offloading.py
+9
-9
tests/hooks/test_group_offloading.py
tests/hooks/test_group_offloading.py
+55
-0
No files found.
src/diffusers/hooks/group_offloading.py
View file @
6cef71de
...
...
@@ -57,7 +57,7 @@ class ModuleGroup:
non_blocking
:
bool
=
False
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
record_stream
:
Optional
[
bool
]
=
False
,
low_cpu_mem_usage
=
False
,
low_cpu_mem_usage
:
bool
=
False
,
onload_self
:
bool
=
True
,
)
->
None
:
self
.
modules
=
modules
...
...
@@ -498,6 +498,8 @@ def _apply_group_offloading_block_level(
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if
stream
is
not
None
and
num_blocks_per_group
!=
1
:
raise
ValueError
(
f
"Using streams is only supported for num_blocks_per_group=1. Got
{
num_blocks_per_group
=
}
."
)
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading
=
set
()
...
...
@@ -521,7 +523,7 @@ def _apply_group_offloading_block_level(
stream
=
stream
,
record_stream
=
record_stream
,
low_cpu_mem_usage
=
low_cpu_mem_usage
,
onload_self
=
stream
is
Non
e
,
onload_self
=
Tru
e
,
)
matched_module_groups
.
append
(
group
)
for
j
in
range
(
i
,
i
+
len
(
current_modules
)):
...
...
@@ -529,12 +531,8 @@ def _apply_group_offloading_block_level(
# Apply group offloading hooks to the module groups
for
i
,
group
in
enumerate
(
matched_module_groups
):
next_group
=
(
matched_module_groups
[
i
+
1
]
if
i
+
1
<
len
(
matched_module_groups
)
and
stream
is
not
None
else
None
)
for
group_module
in
group
.
modules
:
_apply_group_offloading_hook
(
group_module
,
group
,
ne
xt_group
)
_apply_group_offloading_hook
(
group_module
,
group
,
No
ne
)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
...
...
@@ -560,8 +558,10 @@ def _apply_group_offloading_block_level(
record_stream
=
False
,
onload_self
=
True
,
)
next_group
=
matched_module_groups
[
0
]
if
len
(
matched_module_groups
)
>
0
else
None
_apply_group_offloading_hook
(
module
,
unmatched_group
,
next_group
)
if
stream
is
None
:
_apply_group_offloading_hook
(
module
,
unmatched_group
,
None
)
else
:
_apply_lazy_group_offloading_hook
(
module
,
unmatched_group
,
None
)
def
_apply_group_offloading_leaf_level
(
...
...
tests/hooks/test_group_offloading.py
View file @
6cef71de
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
import
gc
import
unittest
...
...
@@ -20,6 +21,7 @@ import torch
from
diffusers.models
import
ModelMixin
from
diffusers.pipelines.pipeline_utils
import
DiffusionPipeline
from
diffusers.utils
import
get_logger
from
diffusers.utils.import_utils
import
compare_versions
from
diffusers.utils.testing_utils
import
require_torch_gpu
,
torch_device
...
...
@@ -58,6 +60,39 @@ class DummyModel(ModelMixin):
return
x
# This model instantiates one type of block (single_blocks) before another type of block (double_blocks),
# but invokes them in the opposite order: first the double_blocks and then the single_blocks.
# With the group offloading implementation prior to https://github.com/huggingface/diffusers/pull/11375, such a model
# would hit a device mismatch error because of ordering assumptions made by the code. The failure case occurs when using:
# offload_type="block_level", num_blocks_per_group=2, use_stream=True
# (the regression test for this model uses num_blocks_per_group=1, the only value accepted together with streams after that PR).
# After the linked PR, the implementation works as expected.
class DummyModelWithMultipleBlocks(ModelMixin):
    """Toy model whose block lists are registered in one order and executed in another.

    ``single_blocks`` is assigned before ``double_blocks`` in ``__init__``, while
    ``forward`` runs every block of ``double_blocks`` first and only then the
    blocks of ``single_blocks``.
    """

    def __init__(
        self, in_features: int, hidden_features: int, out_features: int, num_layers: int, num_single_layers: int
    ) -> None:
        """Build the input/output projections and the two block lists.

        Args:
            in_features: Input width of ``linear_1``.
            hidden_features: Width passed to every ``DummyBlock`` and used between the projections.
            out_features: Output width of ``linear_2``.
            num_layers: Number of blocks in ``double_blocks``.
            num_single_layers: Number of blocks in ``single_blocks``.
        """
        super().__init__()

        def build_blocks(count: int) -> torch.nn.ModuleList:
            # Every block uses the hidden width for all three of its size arguments.
            return torch.nn.ModuleList(
                [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(count)]
            )

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        # Registration order is deliberate: single_blocks first, double_blocks second.
        self.single_blocks = build_blocks(num_single_layers)
        self.double_blocks = build_blocks(num_layers)
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear_1 -> activation -> double_blocks -> single_blocks -> linear_2."""
        hidden = self.activation(self.linear_1(x))
        # Invocation order is the reverse of the registration order above.
        for stage in (self.double_blocks, self.single_blocks):
            for block in stage:
                hidden = block(hidden)
        return self.linear_2(hidden)
class
DummyPipeline
(
DiffusionPipeline
):
model_cpu_offload_seq
=
"model"
...
...
@@ -212,3 +247,23 @@ class GroupOffloadTests(unittest.TestCase):
pipe
.
enable_sequential_cpu_offload
()
with
self
.
assertRaisesRegex
(
ValueError
,
"Cannot apply group offloading"
):
pipe
.
model
.
enable_group_offload
(
torch_device
,
offload_type
=
"block_level"
,
num_blocks_per_group
=
3
)
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
    """Run a model whose blocks execute in a different order than they were registered.

    The model is offloaded with ``offload_type="block_level"``, ``num_blocks_per_group=1``
    and ``use_stream=True``. For installed diffusers versions ``<= 0.33.0`` the forward
    pass is expected to raise a device-mismatch ``RuntimeError``; otherwise it must
    complete without raising.
    """
    if torch.device(torch_device).type != "cuda":
        # Report the test as skipped instead of silently passing without running anything.
        self.skipTest("This test only runs on CUDA devices.")

    model = DummyModelWithMultipleBlocks(
        in_features=self.in_features,
        hidden_features=self.hidden_features,
        out_features=self.out_features,
        num_layers=self.num_layers,
        num_single_layers=self.num_layers + 1,
    )
    model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

    context = contextlib.nullcontext()
    if compare_versions("diffusers", "<=", "0.33.0"):
        # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
        context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")

    with context:
        model(self.input)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment