[feat] cache allocator warmup for `from_single_model` (#12305)

* add
* add a test

Commit 9e7ae568 (unverified), authored Sep 10, 2025 by Sayak Paul, committed via GitHub on Sep 10, 2025
Parent: f7b79452
Showing 2 changed files with 31 additions and 12 deletions.
src/diffusers/loaders/single_file_model.py  (+23, -12)
tests/single_file/test_model_flux_transformer_single_file.py  (+8, -0)
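Taken together, the change threads a new `device_map` keyword through `FromOriginalModelMixin.from_single_file`, resolves it with `_determine_device_map` once the empty model has been built, and, when a device map is present, expands it to per-parameter granularity and pre-warms the caching allocator before the checkpoint weights are converted and loaded.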
src/diffusers/loaders/single_file_model.py
@@ -22,6 +22,7 @@ from huggingface_hub.utils import validate_hf_hub_args
 from typing_extensions import Self

 from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
@@ -297,6 +298,7 @@ class FromOriginalModelMixin:
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        device_map = kwargs.pop("device_map", None)

         user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
         # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
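With this hunk, `from_single_file` accepts a `device_map` keyword. A minimal usage sketch, where the checkpoint path is a placeholder and `FluxTransformer2DModel` is simply the class the new test happens to exercise:

```python
import torch
from diffusers import FluxTransformer2DModel

# "flux1-dev.safetensors" is a placeholder path for illustration; any
# single-file checkpoint supported by the model class works the same way.
model = FluxTransformer2DModel.from_single_file(
    "flux1-dev.safetensors",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```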
@@ -403,19 +405,8 @@ class FromOriginalModelMixin:
         with ctx():
             model = cls.from_config(diffusers_model_config)

-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+        model_state_dict = model.state_dict()
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
-            diffusers_format_checkpoint = checkpoint_mapping_fn(
-                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-            )
-        else:
-            diffusers_format_checkpoint = checkpoint
-
-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )

         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
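The effect of this hunk: `model.state_dict()` is materialized once into `model_state_dict`, and the checkpoint-conversion and validation logic moves below the device-map resolution (see the next hunk), so the caching-allocator warmup can run against the full set of parameter names before any weights are converted or loaded.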
@@ -428,6 +419,26 @@ class FromOriginalModelMixin:
         else:
             keep_in_fp32_modules = []

+        # Now that the model is loaded, we can determine the `device_map`
+        device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+            _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+        if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
+
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
                 model=model,
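For context, a caching-allocator warmup pre-reserves the memory the weights will need, so that the many small per-tensor copies made during loading are served from torch's allocator pool rather than each triggering a separate cudaMalloc. Below is a minimal sketch of the idea; the function name and internals are illustrative, not the actual `_caching_allocator_warmup` from `diffusers.models.model_loading_utils`, which also receives `torch_dtype` and `hf_quantizer` and may account for bytes differently:

```python
from collections import defaultdict

import torch
from torch import nn


def warmup_caching_allocator(model: nn.Module, expanded_device_map: dict) -> None:
    # Sum the bytes that will land on each CUDA device according to the
    # per-parameter device map ({"blocks.0.attn.to_q.weight": "cuda:0", ...}).
    params = dict(model.named_parameters())
    bytes_per_device = defaultdict(int)
    for name, device in expanded_device_map.items():
        param = params.get(name)
        if param is not None and str(device).startswith("cuda"):
            bytes_per_device[str(device)] += param.numel() * param.element_size()

    # Allocate one large block per device and drop it immediately: the freed
    # block stays in torch's caching-allocator pool, so subsequent weight
    # copies reuse it instead of each paying a separate cudaMalloc.
    for device, nbytes in bytes_per_device.items():
        block = torch.empty(nbytes, dtype=torch.uint8, device=device)
        del block
```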
tests/single_file/test_model_flux_transformer_single_file.py
@@ -69,3 +69,11 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
         del model
         gc.collect()
         backend_empty_cache(torch_device)
+
+    def test_device_map_cuda(self):
+        backend_empty_cache(torch_device)
+        model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
+
+        del model
+        gc.collect()
+        backend_empty_cache(torch_device)
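Assuming the usual diffusers test environment and a CUDA device, the new test can be run in isolation with, e.g., `pytest tests/single_file/test_model_flux_transformer_single_file.py -k test_device_map_cuda`.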