Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
46a0c6aa
Unverified
Commit
46a0c6aa
authored
Aug 14, 2025
by
Sayak Paul
Committed by
GitHub
Aug 14, 2025
Browse files
feat: cuda device_map for pipelines. (#12122)
* feat: cuda device_map for pipelines. * up * up * empty * up
parent
421ee07e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
7 deletions
+38
-7
src/diffusers/pipelines/pipeline_loading_utils.py
src/diffusers/pipelines/pipeline_loading_utils.py
+3
-0
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+10
-7
src/diffusers/utils/torch_utils.py
src/diffusers/utils/torch_utils.py
+2
-0
tests/pipelines/test_pipelines_common.py
tests/pipelines/test_pipelines_common.py
+23
-0
No files found.
src/diffusers/pipelines/pipeline_loading_utils.py
View file @
46a0c6aa
...
...
@@ -613,6 +613,9 @@ def _assign_components_to_devices(
def
_get_final_device_map
(
device_map
,
pipeline_class
,
passed_class_obj
,
init_dict
,
library
,
max_memory
,
**
kwargs
):
# TODO: separate out different device_map methods when it gets to it.
if
device_map
!=
"balanced"
:
return
device_map
# To avoid circular import problem.
from
diffusers
import
pipelines
...
...
src/diffusers/pipelines/pipeline_utils.py
View file @
46a0c6aa
...
...
@@ -108,7 +108,7 @@ LIBRARIES = []
for
library
in
LOADABLE_CLASSES
:
LIBRARIES
.
append
(
library
)
SUPPORTED_DEVICE_MAP
=
[
"balanced"
]
SUPPORTED_DEVICE_MAP
=
[
"balanced"
]
+
[
get_device
()]
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -988,12 +988,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
_maybe_warn_for_wrong_component_in_quant_config
(
init_dict
,
quantization_config
)
for
name
,
(
library_name
,
class_name
)
in
logging
.
tqdm
(
init_dict
.
items
(),
desc
=
"Loading pipeline components..."
):
# 7.1 device_map shenanigans
if
final_device_map
is
not
None
and
len
(
final_device_map
)
>
0
:
if
final_device_map
is
not
None
:
if
isinstance
(
final_device_map
,
dict
)
and
len
(
final_device_map
)
>
0
:
component_device
=
final_device_map
.
get
(
name
,
None
)
if
component_device
is
not
None
:
current_device_map
=
{
""
:
component_device
}
else
:
current_device_map
=
None
elif
isinstance
(
final_device_map
,
str
):
current_device_map
=
final_device_map
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name
=
class_name
[
4
:]
if
class_name
.
startswith
(
"Flax"
)
else
class_name
...
...
src/diffusers/utils/torch_utils.py
View file @
46a0c6aa
...
...
@@ -15,6 +15,7 @@
PyTorch utilities: Utilities related to PyTorch
"""
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
.
import
logging
...
...
@@ -168,6 +169,7 @@ def get_torch_cuda_device_capability():
return
None
@
functools
.
lru_cache
def
get_device
():
if
torch
.
cuda
.
is_available
():
return
"cuda"
...
...
tests/pipelines/test_pipelines_common.py
View file @
46a0c6aa
...
...
@@ -2339,6 +2339,29 @@ class PipelineTesterMixin:
f
"Component '
{
name
}
' has dtype
{
component
.
dtype
}
but expected
{
expected_dtype
}
"
,
)
@require_torch_accelerator
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
    """Round-trip test: saving a pipeline and reloading it with
    ``device_map=<accelerator>`` must reproduce the original outputs.

    Args:
        expected_max_difference: upper bound on the max absolute
            element-wise difference between the reference output and the
            output of the reloaded, device-mapped pipeline.
    """
    # Build the reference pipeline directly on the accelerator device.
    dummy_components = self.get_dummy_components()
    pipeline = self.pipeline_class(**dummy_components).to(torch_device)
    pipeline.set_progress_bar_config(disable=None)

    # Reference forward pass with a fixed seed so results are repeatable.
    torch.manual_seed(0)
    pipeline_inputs = self.get_dummy_inputs(torch_device)
    pipeline_inputs["generator"] = torch.manual_seed(0)
    reference_output = pipeline(**pipeline_inputs)[0]

    # Serialize, then reload with device_map pointing at the accelerator.
    with tempfile.TemporaryDirectory() as tmpdir:
        pipeline.save_pretrained(tmpdir)
        reloaded_pipeline = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
        # Reset attention processors where supported so both pipelines
        # run the same attention implementation.
        for module in reloaded_pipeline.components.values():
            if hasattr(module, "set_default_attn_processor"):
                module.set_default_attn_processor()

    # Re-seed the generator and run the reloaded pipeline on the same inputs.
    pipeline_inputs["generator"] = torch.manual_seed(0)
    reloaded_output = reloaded_pipeline(**pipeline_inputs)[0]

    # Outputs must match the reference within the tolerance.
    max_diff = np.abs(to_np(reference_output) - to_np(reloaded_output)).max()
    self.assertLess(max_diff, expected_max_difference)
@
is_staging_test
class
PipelinePushToHubTester
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment