Unverified commit 599c8871, authored by Sayak Paul, committed by GitHub

feat: pipeline-level quantization config (#11130)



* feat: pipeline-level quant config.
Co-authored-by: SunMarc <marc.sun@hotmail.fr>

condition better.

support mapping.

improvements.

[Quantization] Add Quanto backend (#10756)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

[Single File] Add single file loading for SANA Transformer (#10947)

* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)

* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------
Co-authored-by: hlky <hlky@hlky.ac>

[LoRA] CogView4 (#10981)

* update

* make fix-copies

* update

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf

[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)

* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙



* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6  (#11018)

* update

* update

* update

* update

* update

* update

* update

* update

* update

fix: mixture tiling sdxl pipeline - adjust generating time_ids & embeddings (#11012)

small fix on generating time_ids & embeddings

[LoRA] support wan i2v loras from the world. (#11025)

* support wan i2v loras from the world.

* remove copied from.

* updates

* add lora.

Fix SD3 IPAdapter feature extractor (#11027)

chore: fix help messages in advanced diffusion examples (#10923)

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Fix for multi-GPU WAN inference (#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>

[Refactor] Clean up import utils boilerplate (#11026)

* update

* update

* update

Use `output_size` in `repeat_interleave` (#11030)

[hybrid inference 🍯🐝] Add VAE encode (#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review

[LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)

* move to warning.

* test related changes.

Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)

* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>

making ```formatted_images``` initialization compact (#10801)

compact writing
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)

* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------
Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

[Tests] restrict memory tests for quanto for certain schemes. (#11052)

* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] feat: support non-diffusers wan t2v loras. (#11059)

feat: support non-diffusers wan t2v loras.

[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training
Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)

reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.
Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Fix deterministic issue when getting pipeline dtype and device (#10696)
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[Tests] add requires peft decorator. (#11037)

* add requires peft decorator.

* install peft conditionally.

* conditional deps.
Co-authored-by: DN6 <dhruv.nair@gmail.com>

---------
Co-authored-by: DN6 <dhruv.nair@gmail.com>

CogView4 Control Block (#10809)

* cogview4 control training

---------
Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>

[CI] pin transformers version for benchmarking. (#11067)

pin transformers version for benchmarking.

updates

Fix Wan I2V Quality (#11087)

* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update pipeline_wan_i2v.py

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

LTX 0.9.5 (#10968)

* update

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

make PR GPU tests conditioned on styling. (#11099)

Group offloading improvements (#11094)

update

Fix pipeline_flux_controlnet.py (#11095)

* Fix pipeline_flux_controlnet.py

* Fix style

update readme instructions. (#11096)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP

Fix Group offloading behaviour when using streams (#11097)

* update

* update

Quality options in `export_to_video` (#11090)

* Quality options in `export_to_video`

* make style

improve more.

add placeholders for docstrings.

formatting.

smol fix.

solidify validation and annotation

* Revert "feat: pipeline-level quant config."

This reverts commit 316ff46b7648bfa24525ac02c284afcf440404aa.

* feat: implement pipeline-level quantization config
Co-authored-by: SunMarc <marc@huggingface.co>

* update

* fixes

* fix validation.

* add tests and other improvements.

* add tests

* import quality

* remove prints.

* add docs.

* fixes to docs.

* doc fixes.

* doc fixes.

* add validation to the input quantization_config.

* clarify recommendations.

* docs

* add to ci.

* todo.

---------
Co-authored-by: SunMarc <marc@huggingface.co>
parent 393aefcd
@@ -525,6 +525,60 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_nightly_pipeline_level_quantization_tests:
name: Torch quantization nightly tests
strategy:
fail-fast: false
max-parallel: 2
runs-on:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install -U bitsandbytes optimum_quanto
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
- name: Pipeline-level quantization tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_pipeline_level_quant_torch_cuda \
--report-log=tests_pipeline_level_quant_torch_cuda.log \
tests/quantization/test_pipeline_level_quantization.py
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
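
For reference, a rough local equivalent of the CI test invocation above, as a minimal Python sketch (the CI-specific report and log flags are omitted; assumes `pytest` and `pytest-xdist` are installed):

```py
# Sketch: run the new pipeline-level quantization tests locally,
# mirroring the CI step above.
import subprocess

subprocess.run(
    [
        "python", "-m", "pytest", "-n", "1", "--max-worker-restart=0", "--dist=loadfile",
        "tests/quantization/test_pipeline_level_quantization.py",
    ],
    check=True,  # raise if any test fails
)
```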
# M1 runner currently not well supported
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
# run_nightly_tests_apple_m1:
......
@@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License.
# Quantization
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.
<Tip>
@@ -23,6 +21,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
</Tip>
## PipelineQuantizationConfig
[[autodoc]] quantizers.PipelineQuantizationConfig
## BitsAndBytesConfig
......
@@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods.
- [Quanto](./quanto.md)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
## Pipeline-level quantization
Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
do this with [`~quantizers.PipelineQuantizationConfig`].
Start by defining a `PipelineQuantizationConfig`:
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.quantization_config import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": BitsAndBytesConfig(
load_in_4bit=True, compute_dtype=torch.bfloat16
),
}
)
```
Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
```py
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe("photo of a cute dog").images[0]
```
This method allows for granular control over the quantization of individual
model-level components of a pipeline, and it lets different components use different
quantization backends. In the example above, you used a combination of Quanto and bitsandbytes. One
caveat, however, is that users need to know which components come from `transformers` in order
to import the right quantization config class.
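
If you're unsure which library a pipeline component comes from, the repository's `model_index.json` records the originating library for each component. A minimal sketch of inspecting it (assuming `huggingface_hub` is installed):

```py
import json

from huggingface_hub import hf_hub_download

# model_index.json maps each component name to a (library, class) pair,
# e.g. "text_encoder_2": ["transformers", "T5EncoderModel"].
index_path = hf_hub_download("black-forest-labs/FLUX.1-dev", "model_index.json")
with open(index_path) as f:
    model_index = json.load(f)

for name, value in model_index.items():
    if isinstance(value, list):  # skip metadata keys like "_class_name"
        print(f"{name}: {value[0]}")  # "diffusers" or "transformers"
```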
The other method is simpler to use but less flexible. Start by defining a `PipelineQuantizationConfig`, but in a different way:
```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
```
This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] in the same way as the example above.
In this case, `quant_kwargs` is used to initialize the quantization specification
of the configuration class that corresponds to `quant_backend`, and `components_to_quantize`
denotes the components that will be quantized. For most pipelines, you would want to
keep `transformer` in the list, as that is often the most compute- and memory-intensive component.
The config below will work for most diffusion pipelines that have a `transformer` component,
quantizing only that component.
```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer"],
)
```
Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:
* `bitsandbytes_4bit`
* `bitsandbytes_8bit`
* `gguf`
* `quanto`
* `torchao`
Diffusion pipelines can have multiple text encoders ([`FluxPipeline`] has two, for example). It's
recommended to quantize the memory-intensive text encoders, such as T5,
Llama, and Gemma. In the example above, you quantized the T5 model of [`FluxPipeline`] through
`text_encoder_2` while keeping the CLIP model (accessible through `text_encoder`) intact.
\ No newline at end of file
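
For instance, a minimal sketch of a config that quantizes only the T5 encoder, leaving both `transformer` and the CLIP `text_encoder` in full precision:

```py
import torch
from diffusers.quantizers import PipelineQuantizationConfig

# Sketch: quantize only Flux's memory-heavy T5 encoder (`text_encoder_2`);
# `transformer` and the CLIP `text_encoder` stay unquantized.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["text_encoder_2"],
)
```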
@@ -675,8 +675,10 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig
# retrieve class candidates
@@ -769,6 +771,17 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
and issubclass(class_obj, torch.nn.Module)
):
model_quant_config = quantization_config._resolve_quant_config(
is_diffusers=is_diffusers_model, module_name=name
)
if model_quant_config is not None:
loading_kwargs["quantization_config"] = model_quant_config
# check if the module is in a subdirectory
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries
......
@@ -47,6 +47,7 @@ from ..configuration_utils import ConfigMixin
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers import PipelineQuantizationConfig
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
@@ -725,6 +726,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -741,6 +743,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
" install accelerate\n```\n."
)
if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -1001,6 +1006,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
quantization_config=quantization_config,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
......
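
The validation added above fails fast before any sub-models are loaded. A small illustrative sketch of the resulting behavior (using the tiny test checkpoint from the tests further below):

```py
import torch

from diffusers import DiffusionPipeline, QuantoConfig

# Passing a raw dict instead of a PipelineQuantizationConfig raises a
# ValueError before any sub-models are loaded.
try:
    DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-flux-pipe",
        quantization_config={"transformer": QuantoConfig(weights_dtype="int8")},
        torch_dtype=torch.bfloat16,
    )
except ValueError as err:
    print(err)  # `quantization_config` must be an instance of `PipelineQuantizationConfig`.
```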
@@ -12,5 +12,183 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Dict, List, Optional, Union
from ..utils import is_transformers_available, logging
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
try:
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
except ImportError:
class TransformersQuantConfigMixin:
pass
logger = logging.get_logger(__name__)
class PipelineQuantizationConfig:
"""
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
Args:
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
is available to both `diffusers` and `transformers`.
quant_kwargs (`dict`): Params to initialize the quantization backend class.
components_to_quantize (`list`): Components of a pipeline to be quantized.
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
and `components_to_quantize`.
"""
def __init__(
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Default `quant_kwargs` to {} so the backend config class uses its defaults.
self.quant_kwargs = quant_kwargs or {}
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping
self.post_init()
def post_init(self):
quant_mapping = self.quant_mapping
self.is_granular = quant_mapping is not None
self._validate_init_args()
def _validate_init_args(self):
if self.quant_backend and self.quant_mapping:
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
if not self.quant_mapping and not self.quant_backend:
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
if not self.quant_kwargs and not self.quant_mapping:
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
if self.quant_backend is not None:
self._validate_init_kwargs_in_backends()
if self.quant_mapping is not None:
self._validate_quant_mapping_args()
def _validate_init_kwargs_in_backends(self):
quant_backend = self.quant_backend
self._check_backend_availability(quant_backend)
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
if quant_config_mapping_transformers is not None:
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
else:
init_kwargs_transformers = None
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
if init_kwargs_transformers != init_kwargs_diffusers:
raise ValueError(
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
"this mapping would look like."
)
def _validate_quant_mapping_args(self):
quant_mapping = self.quant_mapping
transformers_map, diffusers_map = self._get_quant_config_list()
available_transformers = list(transformers_map.values()) if transformers_map else None
available_diffusers = list(diffusers_map.values())
for module_name, config in quant_mapping.items():
if any(isinstance(config, cfg) for cfg in available_diffusers):
continue
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
continue
if available_transformers:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}; "
f"Available transformers configs: {available_transformers}."
)
else:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}."
)
def _check_backend_availability(self, quant_backend: str):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
available_backends_transformers = (
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
)
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
if (
available_backends_transformers and quant_backend not in available_backends_transformers
) or quant_backend not in quant_config_mapping_diffusers:
error_message = f"Provided quant_backend={quant_backend} was not found."
if available_backends_transformers:
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
raise ValueError(error_message)
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
quant_mapping = self.quant_mapping
components_to_quantize = self.components_to_quantize
# Granular case
if self.is_granular and module_name in quant_mapping:
logger.debug(f"Initializing quantization config class for {module_name}.")
config = quant_mapping[module_name]
return config
# Global config case
else:
should_quantize = False
# Only quantize the modules requested for.
if components_to_quantize and module_name in components_to_quantize:
should_quantize = True
# No specification for `components_to_quantize` means all modules should be quantized.
elif not self.is_granular and not components_to_quantize:
should_quantize = True
if should_quantize:
logger.debug(f"Initializing quantization config class for {module_name}.")
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
quant_config_cls = mapping_to_use[self.quant_backend]
quant_kwargs = self.quant_kwargs
return quant_config_cls(**quant_kwargs)
# Fallback: no applicable configuration found.
return None
def _get_quant_config_list(self):
if is_transformers_available():
from transformers.quantizers.auto import (
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
)
else:
quant_config_mapping_transformers = None
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
return quant_config_mapping_transformers, quant_config_mapping_diffusers
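
A short sketch of how `_resolve_quant_config` behaves in the global (non-granular) case, assuming bitsandbytes is installed:

```py
from diffusers.quantizers import PipelineQuantizationConfig

config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True},
    components_to_quantize=["transformer"],
)

# Components listed in `components_to_quantize` resolve to a backend config...
print(config._resolve_quant_config(is_diffusers=True, module_name="transformer"))
# ...while unlisted components resolve to None and load unquantized.
print(config._resolve_quant_config(is_diffusers=False, module_name="text_encoder"))  # None
```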
@@ -75,7 +75,7 @@ class QuantizationConfigMixin:
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
return_unused_kwargs (`bool`,*optional*, defaults to `False`):
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):
......
@@ -38,6 +38,7 @@ from .import_utils import (
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
is_peft_available,
is_timm_available,
is_torch_available,
@@ -486,6 +487,13 @@ def require_bitsandbytes(test_case):
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
def require_quanto(test_case):
"""
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
"""
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
......
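
A hypothetical usage sketch of the new `require_quanto` decorator, following the same pattern as the existing `require_*` helpers:

```py
import unittest

from diffusers.utils.testing_utils import require_quanto


@require_quanto
class ExampleQuantoTests(unittest.TestCase):  # hypothetical test class
    def test_noop(self):
        # Skipped automatically when optimum-quanto isn't installed.
        self.assertTrue(True)
```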
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
is_transformers_available,
require_accelerate,
require_bitsandbytes_version_greater,
require_quanto,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
if is_transformers_available():
from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
else:
TranBitsAndBytesConfig = None
@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
def test_quant_config_set_correctly_through_kwargs(self):
components_to_quantize = ["transformer", "text_encoder_2"]
quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=components_to_quantize,
)
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for name, component in pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
def test_quant_config_set_correctly_through_granular(self):
quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
components_to_quantize = list(quant_config.quant_mapping.keys())
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for name, component in pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
if name == "text_encoder_2":
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
else:
self.assertTrue(quantization_config.quant_method == "quanto")
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
def test_raises_error_for_invalid_config(self):
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
},
quant_backend="bitsandbytes_4bit",
)
self.assertTrue(
str(err_context.exception)
== "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
)
def test_validation_for_kwargs(self):
components_to_quantize = ["transformer", "text_encoder_2"]
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_backend="quanto",
quant_kwargs={"weights_dtype": "int8"},
components_to_quantize=components_to_quantize,
)
self.assertTrue(
"The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
)
def test_raises_error_for_wrong_config_class(self):
quant_config = {
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
with self.assertRaises(ValueError) as err_context:
_ = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
self.assertTrue(
str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
)
def test_validation_for_mapping(self):
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_mapping={
"transformer": DiffusionPipeline(),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))
def test_saving_loading(self):
quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
components_to_quantize = list(quant_config.quant_mapping.keys())
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
for name, component in loaded_pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
if name == "text_encoder_2":
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
else:
self.assertTrue(quantization_config.quant_method == "quanto")
output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
self.assertTrue(torch.allclose(output_1, output_2))