renzhc / diffusers_dcu · Commits

Unverified commit 58bf2682, authored Aug 14, 2025 by Sayak Paul, committed by GitHub on Aug 14, 2025
support `hf_quantizer` in cache warmup. (#12043)
* support hf_quantizer in cache warmup.
* reviewer feedback
* up
* up
Parent: 1b48db4c
Showing 4 changed files with 50 additions and 9 deletions (+50 −9)
src/diffusers/models/model_loading_utils.py             +11 −6
src/diffusers/models/modeling_utils.py                    +2 −3
src/diffusers/quantizers/base.py                         +11 −0
src/diffusers/quantizers/torchao/torchao_quantizer.py    +26 −0
src/diffusers/models/model_loading_utils.py
@@ -17,7 +17,6 @@
 import functools
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
@@ -717,27 +716,33 @@ def _expand_device_map(device_map, param_names):
 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
     which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
     very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device) for param, device in expanded_device_map.items() if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
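To make the intent of the new code path concrete, here is a minimal standalone sketch of the same byte-counting idea against a plain torch.nn.Module. The helper name `warmup_allocator` and its `factor` argument are illustrative, not part of the diffusers API; the real function above additionally handles buffers, device maps, and cpu/disk filtering.

from collections import defaultdict

import torch


def warmup_allocator(model: torch.nn.Module, device: torch.device, dtype: torch.dtype, factor: int = 2) -> None:
    # Sum the byte footprint of every parameter, honoring each parameter's own dtype
    # (mirrors `param.numel() * param.element_size()` in the diff above).
    total_bytes = defaultdict(int)
    for param in model.parameters():
        total_bytes[device] += param.numel() * param.element_size()

    # One large allocation per device primes the caching allocator, so the many small
    # allocations made while loading the real weights can reuse this pool.
    for dev, byte_count in total_bytes.items():
        _ = torch.empty(byte_count // factor, dtype=dtype, device=dev, requires_grad=False)


# Example: with float16 (2 bytes per element) a factor of 2 pre-allocates roughly the full
# weight footprint; a quantizer could return 4 or 8 to pre-allocate correspondingly less.
model = torch.nn.Linear(1024, 1024, dtype=torch.float16)
warmup_allocator(model, torch.device("cpu"), torch.float16, factor=2)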
src/diffusers/models/modeling_utils.py
@@ -1532,10 +1532,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # tensors using their expected shape and not performing any initialization of the memory (empty data).
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
-        # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)

         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None
src/diffusers/quantizers/base.py
@@ -209,6 +209,17 @@ class DiffusersQuantizer(ABC):
         return model

+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up
+        cuda. A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of
+        4 means we allocate half the memory of the weights residing in the empty model, etc...
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
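To see why the default factor of 4 corresponds to roughly half the model size, here is a small worked calculation; the 1B-parameter model and fp16 dtype below are assumptions for illustration, not values from the commit.

# Illustrative arithmetic only; these variable names are not part of the diffusers API.
numel = 1_000_000_000               # a hypothetical 1B-parameter model
element_size = 2                    # weights counted at fp16 -> 2 bytes per element
byte_count = numel * element_size   # 2_000_000_000 bytes summed by the warmup loop

# factor = 2 (no quantizer): byte_count // 2 fp16 elements are allocated -> 2 GB, the full footprint.
# factor = 4 (default quantizer guess): byte_count // 4 fp16 elements -> 1 GB, i.e. half of it.
print((byte_count // 2) * 2, (byte_count // 4) * 2)  # bytes actually pre-allocated when dtype is fp16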
src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -19,6 +19,7 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union

 from packaging import version
@@ -278,6 +279,31 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
         module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())

+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8) That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types
+        # For the uint types, this is a best guess. Once these types become more used
+        # we can look into their nuances.
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
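The pattern-to-factor lookup relies on `fnmatch`, which does shell-style wildcard matching on the `quant_type` string. A minimal standalone sketch of the same lookup follows; the helper name `warmup_factor_for` and the sample quant_type strings are chosen for illustration.

from fnmatch import fnmatch

# Same mapping as in the method above; keys are shell-style wildcard patterns.
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}


def warmup_factor_for(quant_type: str) -> int:
    # The first pattern that matches wins; unknown types are rejected loudly.
    for pattern, factor in map_to_target_dtype.items():
        if fnmatch(quant_type, pattern):
            return factor
    raise ValueError(f"Unsupported quant_type: {quant_type!r}")


# "int4_weight_only" matches "int4_*" -> 8, so the counted byte total is divided by 8 before
# the warm-up allocation; "float8_weight_only" matches "float8*" -> 4.
print(warmup_factor_for("int4_weight_only"), warmup_factor_for("float8_weight_only"))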