Unverified Commit 9f00c617 authored by Aryan, committed by GitHub

[core] TorchAO Quantizer (#10009)



* torchao quantizer


---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent aafed3f8
@@ -157,6 +157,8 @@
       title: Getting Started
     - local: quantization/bitsandbytes
       title: bitsandbytes
+    - local: quantization/torchao
+      title: torchao
     title: Quantization Methods
 - sections:
   - local: optimization/fp16
...
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 [[autodoc]] BitsAndBytesConfig
 
+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
 ## DiffusersQuantizer
 
 [[autodoc]] quantizers.base.DiffusersQuantizer
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be
 ## When to use what?
 
-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
\ No newline at end of file
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# torchao
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.
```bash
pip install -U torch torchao
```
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
The example below only quantizes the weights to int8.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=dtype,
)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```
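To see the effect of quantization on memory, you can track peak GPU memory around the pipeline call. The snippet below is a minimal sketch using standard PyTorch memory utilities and is not part of the original example; it reuses the `pipe` and `prompt` defined above, and the numbers you get will depend on your hardware.

```python
import torch

# Reset the peak-memory counter, run one generation, and report the peak.
# Compare this against a run without `quantization_config` to see the savings.
torch.cuda.reset_peak_memory_stats()
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```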
TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```
For speed and memory benchmarks on Flux and CogVideoX, refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find torchao's own [benchmark](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
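The sketch below shows roughly what calling autoquant directly on a Diffusers transformer could look like, based on torchao's documented `torchao.autoquant` API. It is not a Diffusers-provided interface, so treat it as an illustration rather than a supported workflow.

```python
import torch
import torchao
from diffusers import FluxTransformer2DModel

# Load the transformer in bfloat16 and let torchao pick a quantization
# strategy per layer by benchmarking on the first real inputs it sees.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/Flux.1-Dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
transformer = torchao.autoquant(torch.compile(transformer, mode="max-autotune"))
```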
The `TorchAoConfig` class accepts three parameters (a combined example follows this list):
- `quant_type`: A string naming one of the quantization types below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantizing the [`FluxTransformer2DModel`]'s first block, specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`.
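The sketch below exercises all three parameters at once: int4 weight-only quantization with a non-default `group_size` (a keyword argument accepted by torchao's `int4_weight_only`), while the first single transformer block is left unquantized. The value `group_size=64` is only illustrative, not a recommendation.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# int4 weight-only quantization with a custom group size; the first single
# transformer block stays in its original precision.
quantization_config = TorchAoConfig(
    "int4wo",
    modules_to_not_convert=["single_transformer_blocks.0"],
    group_size=64,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```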
## Supported quantization types
torchao supports weight-only quantization, as well as combined weight and dynamic-activation quantization, for int8, float3-float8, and uint1-uint7.
Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
Dynamic-activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights and also reduces the memory overhead from activation computation. However, it can come at a quality cost, so test the quantized model thoroughly.
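In terms of configuration, choosing between the two is just a matter of the `quant_type` string; a minimal sketch of the two int8 variants:

```python
from diffusers import TorchAoConfig

# Weight-only: int8 weights, activations stay in bfloat16/float16.
weight_only_config = TorchAoConfig("int8wo")

# Dynamic activation: int8 weights, with activations quantized on the fly.
dynamic_activation_config = TorchAoConfig("int8dq")
```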
The quantization methods supported are as follows:
| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
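Because the full function names are accepted as-is, their torchao keyword arguments carry over too. A small sketch using `uintx_weight_only` with an explicit `dtype` (the `uint4wo` shorthand bakes this choice in), assuming torchao is installed:

```python
import torch
from diffusers import TorchAoConfig

# Two equivalent ways to request 4-bit unsigned-integer weight-only quantization:
# the full function name with an explicit dtype, or the baked-in shorthand.
config_full = TorchAoConfig("uintx_weight_only", dtype=torch.uint4)
config_short = TorchAoConfig("uint4wo")
```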
Refer to the official torchao documentation for a deeper understanding of the available quantization methods and the exhaustive list of configuration options.
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
@@ -31,7 +31,7 @@ _import_structure = {
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig"],
+    "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -569,7 +569,7 @@ else:
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig
+    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
 
     try:
         if not is_onnx_available():
...
@@ -25,7 +25,6 @@ import safetensors
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
-from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -182,7 +181,6 @@ def load_model_dict_into_meta(
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -215,12 +213,12 @@ def load_model_dict_into_meta(
         # bnb params are flattened.
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
                 hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
...
@@ -700,10 +700,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            if device_map is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
                 )
 
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
@@ -858,13 +860,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             if device_map is None and not is_sharded:
                 # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
                 # It would error out during the `validate_environment()` call above in the absence of cuda.
-                is_quant_method_bnb = (
-                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-                )
                 if hf_quantizer is None:
                     param_device = "cpu"
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                elif is_quant_method_bnb:
+                else:
                     param_device = torch.device(torch.cuda.current_device())
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
...
@@ -19,17 +19,20 @@ import warnings
 from typing import Dict, Optional, Union
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
+from .torchao import TorchAoHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+    "torchao": TorchAoHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
+    "torchao": TorchAoConfig,
 }
...
@@ -22,15 +22,17 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
 import copy
 import importlib.metadata
+import inspect
 import json
 import os
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Union
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
 
 from packaging import version
 
-from ..utils import is_torch_available, logging
+from ..utils import is_torch_available, is_torchao_available, logging
 
 
 if is_torch_available():
@@ -41,6 +43,7 @@ logger = logging.get_logger(__name__)
 class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
+    TORCHAO = "torchao"
 
 
 @dataclass
@@ -389,3 +392,254 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
             serializable_config_dict[key] = value
 
         return serializable_config_dict
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`str`):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
`int8_weight_only`, `int8_dynamic_activation_int8_weight`
- Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
- **Floating point 8-bit quantization:**
- Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
`float8_static_activation_float8_weight`
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,
- **Floating point X-bit quantization:**
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
be satisfied for a given shorthand notation.
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
kwargs (`Dict[str, Any]`, *optional*):
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
documentation of arguments can be found in
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
Example:
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
```
"""
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
# When we load from serialized config, "quant_type_kwargs" will be the key
if "quant_type_kwargs" in kwargs:
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
else:
self.quant_type_kwargs = kwargs
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
@classmethod
def _get_torchao_quant_type_to_method(cls):
r"""
Returns supported torchao quantization types with all commonly used notations.
"""
if is_torchao_available():
# TODO(aryan): Support autoquant and sparsify
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
def generate_float8dq_types(dtype: torch.dtype):
name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
types = {}
for granularity_cls in [PerTensor, PerRow]:
# Note: Activation and Weights cannot have different granularities
granularity_name = "tensor" if granularity_cls is PerTensor else "row"
types[f"float8dq_{name}_{granularity_name}"] = partial(
float8_dynamic_activation_float8_weight,
activation_dtype=dtype,
weight_dtype=dtype,
granularity=(granularity_cls(), granularity_cls()),
)
return types
def generate_fpx_quantization_types(bits: int):
types = {}
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
return types
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
"int4wo": int4_weight_only,
"int4_weight_only": int4_weight_only,
# int4 weight + int8 activation
"int4dq": int8_dynamic_activation_int4_weight,
"int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
}
INT8_QUANTIZATION_TYPES = {
# int8 weight + bfloat16/float16 activation
"int8wo": int8_weight_only,
"int8_weight_only": int8_weight_only,
# int8 weight + int8 activation
"int8dq": int8_dynamic_activation_int8_weight,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
}
# TODO(aryan): handle torch 2.2/2.3
FLOATX_QUANTIZATION_TYPES = {
# float8_e5m2 weight + bfloat16/float16 activation
"float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
"float8_weight_only": float8_weight_only,
"float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
# float8_e4m3 weight + bfloat16/float16 activation
"float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
# float8_e5m2 weight + float8 activation (dynamic)
"float8dq": float8_dynamic_activation_float8_weight,
"float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
# ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
# However, changing activation_dtype=torch.float8_e4m3 might work here =====
# "float8dq_e5m2": partial(
# float8_dynamic_activation_float8_weight,
# activation_dtype=torch.float8_e5m2,
# weight_dtype=torch.float8_e5m2,
# ),
# **generate_float8dq_types(torch.float8_e5m2),
# ===== =====
# float8_e4m3 weight + float8 activation (dynamic)
"float8dq_e4m3": partial(
float8_dynamic_activation_float8_weight,
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
),
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
# fpx weight + bfloat16/float16 activation
**generate_fpx_quantization_types(3),
**generate_fpx_quantization_types(4),
**generate_fpx_quantization_types(5),
**generate_fpx_quantization_types(6),
**generate_fpx_quantization_types(7),
}
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
"uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
"uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
"uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
"uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
"uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
"uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
# "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
}
QUANTIZATION_TYPES = {}
QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
if cls._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
return QUANTIZATION_TYPES
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)
@staticmethod
def _is_cuda_capability_atleast_8_9() -> bool:
if not torch.cuda.is_available():
raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
def __repr__(self):
r"""
Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
```
TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "uint_a16w4",
"quant_type_kwargs": {
"group_size": 32
}
}
```
"""
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .torchao_quantizer import TorchAoHfQuantizer
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from
https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py
"""
import importlib
import types
from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
if is_torch_available():
import torch
import torch.nn as nn
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
# At the moment, only int8 is supported for integer quantization dtypes.
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
# to support more quantization methods, such as intx_weight_only.
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint1,
torch.uint2,
torch.uint3,
torch.uint4,
torch.uint5,
torch.uint6,
torch.uint7,
)
if is_torchao_available():
from torchao.quantization import quantize_
logger = logging.get_logger(__name__)
def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
if isinstance(weight, AffineQuantizedTensor):
return f"{weight.__class__.__name__}({weight._quantization_type()})"
if isinstance(weight, LinearActivationQuantizedTensor):
return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
def _linear_extra_repr(self):
weight = _quantization_type(self.weight)
if weight is None:
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
else:
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
class TorchAoHfQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/.
"""
requires_calibration = False
required_packages = ["torchao"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
if not is_torchao_available():
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
self.offload = False
device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
if self.pre_quantized:
raise ValueError(
"You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
"This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
)
else:
self.offload = True
if self.pre_quantized:
weights_only = kwargs.get("weights_only", None)
if weights_only:
torch_version = version.parse(importlib.metadata.version("torch"))
if torch_version < version.parse("2.5.0"):
# TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future
raise RuntimeError(
f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
)
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int"):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
if torch_dtype is None:
# We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
logger.warning(
"Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
"to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
"dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
)
torch_dtype = torch.bfloat16
return torch_dtype
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int8") or quant_type.startswith("int4"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
if target_dtype in SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION:
return target_dtype
# We need one of the supported dtypes to be selected in order for accelerate to determine
# the total size of modules/parameters for auto device placement.
possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
raise ValueError(
f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.9 for key, val in max_memory.items()}
return max_memory
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
param_device = kwargs.pop("param_device", None)
# Check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
return False
elif param_device == "cpu" and self.offload:
# We don't quantize weights that we offload
return False
else:
# We only quantize the weight of nn.Linear
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
):
r"""
Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
then we move it to the target device. Finally, we quantize the module.
"""
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
# If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
# about AffineQuantizedTensor
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
# As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
self.modules_to_not_convert.extend(keep_in_fp32_modules)
# Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
self.modules_to_not_convert.extend(keys_on_cpu)
# Purge `None`.
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "ModelMixin"):
return model
def is_serializable(self, safe_serialization=None):
# TODO(aryan): needs to be tested
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
)
return False
_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
"0.25.0"
)
if not _is_torchao_serializable:
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
if self.offload and self.quantization_config.modules_to_not_convert is None:
logger.warning(
"The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
)
return False
return _is_torchao_serializable
@property
def is_trainable(self):
return self.quantization_config.quant_type.startswith("int8")
@@ -87,6 +87,7 @@ from .import_utils import (
     is_torch_version,
     is_torch_xla_available,
     is_torch_xla_version,
+    is_torchao_available,
     is_torchsde_available,
     is_torchvision_available,
     is_transformers_available,
...
@@ -340,6 +340,15 @@ if _imageio_available:
     _imageio_available = False
 
 
+_is_torchao_available = importlib.util.find_spec("torchao") is not None
+if _is_torchao_available:
+    try:
+        _torchao_version = importlib_metadata.version("torchao")
+        logger.debug(f"Successfully import torchao version {_torchao_version}")
+    except importlib_metadata.PackageNotFoundError:
+        _is_torchao_available = False
+
+
 def is_torch_available():
     return _torch_available
@@ -460,6 +469,10 @@ def is_imageio_available():
     return _imageio_available
 
 
+def is_torchao_available():
+    return _is_torchao_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -593,6 +606,11 @@ IMAGEIO_IMPORT_ERROR = """
 {0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
 """
 
+# docstyle-ignore
+TORCHAO_IMPORT_ERROR = """
+{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao`
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -618,6 +636,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
         ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
         ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
+        ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
     ]
 )
...
@@ -39,6 +39,7 @@ from .import_utils import (
     is_timm_available,
     is_torch_available,
     is_torch_version,
+    is_torchao_available,
     is_torchsde_available,
     is_transformers_available,
 )
@@ -476,6 +477,18 @@ def require_bitsandbytes_version_greater(bnb_version):
     return decorator
 
 
+def require_torchao_version_greater(torchao_version):
+    def decorator(test_case):
+        correct_torchao_version = is_torchao_available() and version.parse(
+            version.parse(importlib.metadata.version("torchao")).base_version
+        ) > version.parse(torchao_version)
+        return unittest.skipUnless(
+            correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
...
The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/).
The benchmarks were run on a single H100. Below is the output of `nvidia-smi`:
```bash
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 |
| N/A 34C P0 69W / 700W | 2MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
```
The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR.
The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
```
`diffusers-cli`:
```bash
- 🤗 Diffusers version: 0.32.0.dev0
- Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.6.0.dev20241112+cu121 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.26.2
- Transformers version: 4.46.3
- Accelerate version: 1.1.1
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
```