Unverified Commit 3d8d8485 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

fix invalid component handling behaviour in `PipelineQuantizationConfig` (#11750)

* start

* updates
parent 195926bb
...@@ -1131,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
            break
    if has_transformers_component and not is_transformers_version(">", "4.47.1"):
        raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
    """Warn when the quantization config names components the pipeline does not have.

    Args:
        pipe_init_dict: Mapping of component name -> (library, class) for the pipeline
            being loaded; only its keys are inspected.
        quant_config: A ``PipelineQuantizationConfig`` or ``None``. Component names are
            taken from ``components_to_quantize`` when set, otherwise from the keys of
            ``quant_mapping`` (when it is a dict).

    Components listed in the config but absent from the pipeline are silently skipped
    at load time; this emits a single warning so users understand why those entries
    had no effect.
    """
    if quant_config is None:
        return

    actual_pipe_components = set(pipe_init_dict.keys())

    # Prefer the explicit `components_to_quantize` list; fall back to the keys of a
    # dict-style `quant_mapping`. Use getattr so either attribute may be absent.
    quant_components = None
    if getattr(quant_config, "components_to_quantize", None) is not None:
        quant_components = set(quant_config.components_to_quantize)
    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
        quant_components = set(quant_config.quant_mapping.keys())

    # Fix: `missing` was previously initialized to "" (a str) and then reassigned a
    # set, with a redundant second truthiness check after the `issubset` test.
    # Compute the set difference once so the type is consistent throughout.
    missing = quant_components - actual_pipe_components if quant_components else set()
    if missing:
        logger.warning(
            f"The following components in the quantization config {missing} will be ignored "
            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
            f"components are: {', '.join(actual_pipe_components)}."
        )
...@@ -88,6 +88,7 @@ from .pipeline_loading_utils import (
    _identify_model_variants,
    _maybe_raise_error_for_incorrect_transformers,
    _maybe_raise_warning_for_inpainting,
    _maybe_warn_for_wrong_component_in_quant_config,
    _resolve_custom_pipeline_and_cls,
    _unwrap_model,
    _update_init_kwargs_with_connected_pipeline,
...@@ -984,6 +985,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        # 7. Load each module in the pipeline
        current_device_map = None
        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 7.1 device_map shenanigans
            if final_device_map is not None and len(final_device_map) > 0:
...
...@@ -16,10 +16,13 @@ import tempfile
import unittest

import torch
from parameterized import parameterized

from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_transformers_available,
    require_accelerate,
    require_bitsandbytes_version_greater,
...@@ -188,3 +191,55 @@ class PipelineQuantizationTests(unittest.TestCase):
        output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
        self.assertTrue(torch.allclose(output_1, output_2))
@parameterized.expand(["quant_kwargs", "quant_mapping"])
def test_warn_invalid_component(self, method):
    """A component named in the quantization config but absent from the pipeline
    must emit a warning (not raise) during `from_pretrained`.

    Parameterized over both ways of declaring components in
    `PipelineQuantizationConfig`: the `components_to_quantize` list (with
    `quant_kwargs`) and an explicit per-component `quant_mapping`.
    """
    invalid_component = "foo"  # not a component of the test pipeline
    if method == "quant_kwargs":
        components_to_quantize = ["transformer", invalid_component]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=components_to_quantize,
        )
    else:
        # quant_mapping path: per-component backend configs keyed by component name.
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig("int8"),
                invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
            }
        )
    # The warning is emitted from the pipeline-loading utilities module, so capture
    # that module's logger specifically.
    logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
    logger.setLevel(logging.WARNING)
    with CaptureLogger(logger) as cap_logger:
        _ = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        )
    # The invalid component's name should appear in the captured warning text.
    self.assertTrue(invalid_component in cap_logger.out)
@parameterized.expand(["quant_kwargs", "quant_mapping"])
def test_no_quantization_for_all_invalid_components(self, method):
    """When every component in the quantization config is invalid, loading must
    still succeed and no component may end up quantized.

    Parameterized over both `components_to_quantize`/`quant_kwargs` and the
    explicit `quant_mapping` form of `PipelineQuantizationConfig`.
    """
    invalid_component = "foo"  # not a component of the test pipeline
    if method == "quant_kwargs":
        components_to_quantize = [invalid_component]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=components_to_quantize,
        )
    else:
        quant_config = PipelineQuantizationConfig(
            quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
        )
    pipe = DiffusionPipeline.from_pretrained(
        self.model_name,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )
    # Quantized models carry a `quantization_config` on their config; since every
    # configured component was invalid, none of the loaded modules should have it.
    for name, component in pipe.components.items():
        if isinstance(component, torch.nn.Module):
            self.assertTrue(not hasattr(component.config, "quantization_config"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment