renzhc / diffusers_dcu · Commits

Commit 3d8d8485 (unverified)
Authored Jun 20, 2025 by Sayak Paul; committed by GitHub on Jun 20, 2025
fix invalid component handling behaviour in `PipelineQuantizationConfig` (#11750)
* start
* updates
Parent: 195926bb
Changes: 3 changed files with 80 additions and 0 deletions (+80 -0)

src/diffusers/pipelines/pipeline_loading_utils.py        +23  -0
src/diffusers/pipelines/pipeline_utils.py                  +2  -0
tests/quantization/test_pipeline_level_quantization.py    +55  -0
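In short, the commit makes pipeline loading warn about, and then ignore, entries in the quantization config that name components the pipeline does not actually have. A minimal sketch of the user-facing scenario, mirroring the new tests further below (the checkpoint id is a placeholder, not part of this commit):

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# "foo" is not a component of the pipeline, so with this change it is reported
# in a warning and skipped instead of being handled incorrectly.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "foo"],
)
pipe = DiffusionPipeline.from_pretrained(
    "org/some-diffusion-checkpoint",  # placeholder model id, assumed for illustration
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)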
src/diffusers/pipelines/pipeline_loading_utils.py (view file @ 3d8d8485)
@@ -1131,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
                 break
     if has_transformers_component and not is_transformers_version(">", "4.47.1"):
         raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
+    if quant_config is None:
+        return
+
+    actual_pipe_components = set(pipe_init_dict.keys())
+    missing = ""
+    quant_components = None
+    if getattr(quant_config, "components_to_quantize", None) is not None:
+        quant_components = set(quant_config.components_to_quantize)
+    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
+        quant_components = set(quant_config.quant_mapping.keys())
+
+    if quant_components and not quant_components.issubset(actual_pipe_components):
+        missing = quant_components - actual_pipe_components
+
+    if missing:
+        logger.warning(
+            f"The following components in the quantization config {missing} will be ignored "
+            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
+            f"components are: {', '.join(actual_pipe_components)}."
+        )
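For reference, a tiny standalone sketch of the set arithmetic the new helper performs (toy component names, assumed purely for illustration):

# Mirrors the subset/difference check in _maybe_warn_for_wrong_component_in_quant_config.
actual_pipe_components = {"transformer", "text_encoder", "vae", "scheduler"}  # toy pipeline components
quant_components = {"transformer", "foo"}  # "foo" does not belong to the pipeline

missing = set()
if quant_components and not quant_components.issubset(actual_pipe_components):
    missing = quant_components - actual_pipe_components

print(missing)  # {'foo'} -> these names end up in the warning and are ignored during loading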
src/diffusers/pipelines/pipeline_utils.py (view file @ 3d8d8485)
@@ -88,6 +88,7 @@ from .pipeline_loading_utils import (
     _identify_model_variants,
     _maybe_raise_error_for_incorrect_transformers,
     _maybe_raise_warning_for_inpainting,
+    _maybe_warn_for_wrong_component_in_quant_config,
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
     _update_init_kwargs_with_connected_pipeline,

@@ -984,6 +985,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

         # 7. Load each module in the pipeline
         current_device_map = None
+        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
             # 7.1 device_map shenanigans
             if final_device_map is not None and len(final_device_map) > 0:
tests/quantization/test_pipeline_level_quantization.py (view file @ 3d8d8485)
@@ -16,10 +16,13 @@ import tempfile
 import unittest

 import torch
+from parameterized import parameterized

 from diffusers import DiffusionPipeline, QuantoConfig
 from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     is_transformers_available,
     require_accelerate,
     require_bitsandbytes_version_greater,

@@ -188,3 +191,55 @@ class PipelineQuantizationTests(unittest.TestCase):
         output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

         self.assertTrue(torch.allclose(output_1, output_2))
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_warn_invalid_component(self, method):
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = ["transformer", invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={
+                    "transformer": QuantoConfig("int8"),
+                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
+                }
+            )
+
+        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
+        logger.setLevel(logging.WARNING)
+        with CaptureLogger(logger) as cap_logger:
+            _ = DiffusionPipeline.from_pretrained(
+                self.model_name,
+                quantization_config=quant_config,
+                torch_dtype=torch.bfloat16,
+            )
+        self.assertTrue(invalid_component in cap_logger.out)
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_no_quantization_for_all_invalid_components(self, method):
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = [invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
+            )
+
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        )
+        for name, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                self.assertTrue(not hasattr(component.config, "quantization_config"))
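If helpful, the two new tests can be selected locally with pytest's keyword filter, e.g. `pytest tests/quantization/test_pipeline_level_quantization.py -k invalid`, assuming the repository's test requirements (including bitsandbytes and accelerate) are installed.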