Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b73c7383
Unverified
Commit
b73c7383
authored
Jul 15, 2025
by
Aryan
Committed by
GitHub
Jul 15, 2025
Browse files
Remove device synchronization when loading weights (#11927)
* update
* make style
parent
06fd4277
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing 6 changed files with 6 additions and 24 deletions (+6, -24)
src/diffusers/loaders/single_file_model.py
src/diffusers/loaders/single_file_model.py
+1
-4
src/diffusers/loaders/single_file_utils.py
src/diffusers/loaders/single_file_utils.py
+1
-7
src/diffusers/loaders/transformer_flux.py
src/diffusers/loaders/transformer_flux.py
+1
-3
src/diffusers/loaders/transformer_sd3.py
src/diffusers/loaders/transformer_sd3.py
+1
-3
src/diffusers/loaders/unet.py
src/diffusers/loaders/unet.py
+1
-3
src/diffusers/models/modeling_utils.py
src/diffusers/models/modeling_utils.py
+1
-4
No files found.
src/diffusers/loaders/single_file_model.py
View file @
b73c7383
...
...
@@ -24,7 +24,7 @@ from typing_extensions import Self
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
...
...
@@ -431,10 +431,7 @@ class FromOriginalModelMixin:
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
...
...
src/diffusers/loaders/single_file_utils.py
View file @
b73c7383
...
...
@@ -46,7 +46,7 @@ from ..utils import (
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache

 if is_transformers_available():
...
...
@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)
...
...
@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint)
...
...
src/diffusers/loaders/transformer_flux.py
View file @
b73c7383
...
...
@@ -19,7 +19,7 @@ from ..models.embeddings import (
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache

 if is_accelerate_available():
...
...
@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin:
         device_map = {"": self.device}
         load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_projection
...
...
@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin:
             key_id += 1

         empty_device_cache()
-        device_synchronize()

         return attn_procs
...
...
src/diffusers/loaders/transformer_sd3.py
View file @
b73c7383
...
...
@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache

 logger = logging.get_logger(__name__)
...
...
@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin:
             )

         empty_device_cache()
-        device_synchronize()

         return attn_procs
...
...
@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin:
         device_map = {"": self.device}
         load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_proj
...
...
src/diffusers/loaders/unet.py
View file @
b73c7383
...
...
@@ -43,7 +43,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
...
...
@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin:
         device_map = {"": self.device}
         load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_projection
...
...
@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin:
             key_id += 2

         empty_device_cache()
-        device_synchronize()

         return attn_procs
...
...
src/diffusers/models/modeling_utils.py
View file @
b73c7383
...
...
@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
     load_or_create_model_card,
     populate_model_card,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
...
...
@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

             error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()

         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment