Unverified Commit 9f00c617 authored by Aryan, committed by GitHub

[core] TorchAO Quantizer (#10009)



* torchao quantizer


---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent aafed3f8
......@@ -157,6 +157,8 @@
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/torchao
title: torchao
title: Quantization Methods
- sections:
- local: optimization/fp16
......
......@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
[[autodoc]] BitsAndBytesConfig
## TorchAoConfig
[[autodoc]] TorchAoConfig
## DiffusersQuantizer
[[autodoc]] quantizers.base.DiffusersQuantizer
......@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
\ No newline at end of file
Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
\ No newline at end of file
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# torchao
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.
```bash
pip install -U torch torchao
```
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
The example below only quantizes the weights to int8.
```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=dtype,
)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```
TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```
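A quantized model can also be saved with [`~ModelMixin.save_pretrained`] and reloaded later, since [`~ModelMixin.from_pretrained`] supports pre-quantized checkpoints. The sketch below follows the serialization tests added in this PR; torchao tensor subclasses do not support safetensors, so `safe_serialization=False` and `use_safetensors=False` are assumed, and the save path is only a placeholder.
```python
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Safetensors serialization is not supported for torchao quantized tensors.
transformer.save_pretrained("flux-transformer-int8wo", safe_serialization=False)

# Reload the pre-quantized weights; no quantization_config is needed here.
transformer = FluxTransformer2DModel.from_pretrained(
    "flux-transformer-int8wo", torch_dtype=torch.bfloat16, use_safetensors=False
)
```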
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find torchao [benchmark](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
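As a rough illustration, the sketch below applies autoquant directly to the transformer. It assumes torchao's top-level `autoquant` entry point as described in the torchao README; it is not a Diffusers API, and the calibration forward pass shown is only indicative.
```python
import torch
import torchao

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Wrap the underlying modeling component; autoquant benchmarks candidate
# quantization strategies on the shapes it observes and keeps the fastest one.
pipe.transformer = torchao.autoquant(torch.compile(pipe.transformer, mode="max-autotune"))

# The first call triggers the benchmarking/selection pass.
pipe("A cat holding a sign that says hello world", num_inference_steps=28, guidance_scale=0.0)
```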
The `TorchAoConfig` class accepts three parameters:
- `quant_type`: A string value mentioning one of the quantization types below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantizing the first block of the [`FluxTransformer2DModel`], specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`. A combined example follows this list.
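For illustration, the hedged sketch below combines all three parameters; `group_size` is a keyword argument accepted by torchao's `int4_weight_only` and is forwarded to it through `kwargs`.
```python
from diffusers import TorchAoConfig

# int4 weight-only quantization, skipping the first single transformer block,
# with `group_size` forwarded to torchao's `int4_weight_only`.
quantization_config = TorchAoConfig(
    "int4wo",
    modules_to_not_convert=["single_transformer_blocks.0"],
    group_size=64,
)
```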
## Supported quantization types
torchao supports weight-only quantization, as well as combined weight and dynamic-activation quantization, for int8, float3-float8, and uint1-uint7.
Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
The quantization methods supported are as follows:
| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAmB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
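For example, the two configurations below are interchangeable because the shorthand maps to the same underlying torchao method:
```python
from diffusers import TorchAoConfig

# "int8wo" is the shorthand alias for "int8_weight_only"; both resolve to
# torchao's `int8_weight_only` and produce equivalent quantization.
config_short = TorchAoConfig("int8wo")
config_full = TorchAoConfig("int8_weight_only")
```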
Refer to the official torchao documentation for a deeper understanding of the available quantization methods and an exhaustive list of configuration options.
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
......@@ -31,7 +31,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
......@@ -569,7 +569,7 @@ else:
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig
from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
try:
if not is_onnx_available():
......
......@@ -25,7 +25,6 @@ import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
......@@ -182,7 +181,6 @@ def load_model_dict_into_meta(
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
empty_state_dict = model.state_dict()
......@@ -215,12 +213,12 @@ def load_model_dict_into_meta(
# bnb params are flattened.
if empty_state_dict[param_name].shape != param.shape:
if (
is_quant_method_bnb
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
elif not is_quant_method_bnb:
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
......
......@@ -700,10 +700,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer = None
if hf_quantizer is not None:
if device_map is not None:
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
if is_bnb_quantization_method and device_map is not None:
raise NotImplementedError(
"Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
)
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
......@@ -858,13 +860,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if device_map is None and not is_sharded:
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
# It would error out during the `validate_environment()` call above in the absence of cuda.
is_quant_method_bnb = (
getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
)
if hf_quantizer is None:
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
......
......@@ -19,17 +19,20 @@ import warnings
from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
from .torchao import TorchAoHfQuantizer
AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"torchao": TorchAoHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"torchao": TorchAoConfig,
}
......
......@@ -22,15 +22,17 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
import copy
import importlib.metadata
import inspect
import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Union
from functools import partial
from typing import Any, Dict, List, Optional, Union
from packaging import version
from ..utils import is_torch_available, logging
from ..utils import is_torch_available, is_torchao_available, logging
if is_torch_available():
......@@ -41,6 +43,7 @@ logger = logging.get_logger(__name__)
class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
TORCHAO = "torchao"
@dataclass
......@@ -389,3 +392,254 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
serializable_config_dict[key] = value
return serializable_config_dict
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`str`):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
`int8_weight_only`, `int8_dynamic_activation_int8_weight`
- Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
- **Floating point 8-bit quantization:**
- Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
`float8_static_activation_float8_weight`
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8dq_e4m3_tensor`, `float8dq_e4m3_row`
- **Floating point X-bit quantization:**
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAmB`, where `X` is the number of bits (between `1` and `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
be satisfied for a given shorthand notation.
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
modules_to_not_convert (`List[str]`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require some
modules to be left in their original precision.
kwargs (`Dict[str, Any]`, *optional*):
The keyword arguments for the chosen type of quantization. For example, `int4_weight_only` quantization
currently supports the `group_size` and `inner_k_tiles` keyword arguments. More API examples and
documentation of arguments can be found in
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
Example:
```python
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
```
"""
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
# When we load from serialized config, "quant_type_kwargs" will be the key
if "quant_type_kwargs" in kwargs:
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
else:
self.quant_type_kwargs = kwargs
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
@classmethod
def _get_torchao_quant_type_to_method(cls):
r"""
Returns supported torchao quantization types with all commonly used notations.
"""
if is_torchao_available():
# TODO(aryan): Support autoquant and sparsify
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
def generate_float8dq_types(dtype: torch.dtype):
name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
types = {}
for granularity_cls in [PerTensor, PerRow]:
# Note: Activation and Weights cannot have different granularities
granularity_name = "tensor" if granularity_cls is PerTensor else "row"
types[f"float8dq_{name}_{granularity_name}"] = partial(
float8_dynamic_activation_float8_weight,
activation_dtype=dtype,
weight_dtype=dtype,
granularity=(granularity_cls(), granularity_cls()),
)
return types
def generate_fpx_quantization_types(bits: int):
types = {}
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
return types
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
"int4wo": int4_weight_only,
"int4_weight_only": int4_weight_only,
# int4 weight + int8 activation
"int4dq": int8_dynamic_activation_int4_weight,
"int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
}
INT8_QUANTIZATION_TYPES = {
# int8 weight + bfloat16/float16 activation
"int8wo": int8_weight_only,
"int8_weight_only": int8_weight_only,
# int8 weight + int8 activation
"int8dq": int8_dynamic_activation_int8_weight,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
}
# TODO(aryan): handle torch 2.2/2.3
FLOATX_QUANTIZATION_TYPES = {
# float8_e5m2 weight + bfloat16/float16 activation
"float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
"float8_weight_only": float8_weight_only,
"float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
# float8_e4m3 weight + bfloat16/float16 activation
"float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
# float8_e5m2 weight + float8 activation (dynamic)
"float8dq": float8_dynamic_activation_float8_weight,
"float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
# ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
# However, changing activation_dtype=torch.float8_e4m3 might work here =====
# "float8dq_e5m2": partial(
# float8_dynamic_activation_float8_weight,
# activation_dtype=torch.float8_e5m2,
# weight_dtype=torch.float8_e5m2,
# ),
# **generate_float8dq_types(torch.float8_e5m2),
# ===== =====
# float8_e4m3 weight + float8 activation (dynamic)
"float8dq_e4m3": partial(
float8_dynamic_activation_float8_weight,
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
),
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
# fpx weight + bfloat16/float16 activation
**generate_fpx_quantization_types(3),
**generate_fpx_quantization_types(4),
**generate_fpx_quantization_types(5),
**generate_fpx_quantization_types(6),
**generate_fpx_quantization_types(7),
}
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
"uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
"uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
"uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
"uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
"uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
"uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
# "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
}
QUANTIZATION_TYPES = {}
QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
if cls._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
return QUANTIZATION_TYPES
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)
@staticmethod
def _is_cuda_capability_atleast_8_9() -> bool:
if not torch.cuda.is_available():
raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
def __repr__(self):
r"""
Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
```
TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "uint_a16w4",
"quant_type_kwargs": {
"group_size": 32
}
}
```
"""
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .torchao_quantizer import TorchAoHfQuantizer
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from
https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py
"""
import importlib
import types
from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
if is_torch_available():
import torch
import torch.nn as nn
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
# At the moment, only int8 is supported for integer quantization dtypes.
# In Torch 2.6, int1-int7 will be introduced, so this can be revisited in the future
# to support more quantization methods, such as intx_weight_only.
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint1,
torch.uint2,
torch.uint3,
torch.uint4,
torch.uint5,
torch.uint6,
torch.uint7,
)
if is_torchao_available():
from torchao.quantization import quantize_
logger = logging.get_logger(__name__)
def _quantization_type(weight):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
if isinstance(weight, AffineQuantizedTensor):
return f"{weight.__class__.__name__}({weight._quantization_type()})"
if isinstance(weight, LinearActivationQuantizedTensor):
return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
def _linear_extra_repr(self):
weight = _quantization_type(self.weight)
if weight is None:
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
else:
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
class TorchAoHfQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/.
"""
requires_calibration = False
required_packages = ["torchao"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
if not is_torchao_available():
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
self.offload = False
device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
if self.pre_quantized:
raise ValueError(
"You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
"This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
)
else:
self.offload = True
if self.pre_quantized:
weights_only = kwargs.get("weights_only", None)
if weights_only:
torch_version = version.parse(importlib.metadata.version("torch"))
if torch_version < version.parse("2.5.0"):
# TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future
raise RuntimeError(
f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
)
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int"):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
if torch_dtype is None:
# We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
logger.warning(
"Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
"to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
"dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
)
torch_dtype = torch.bfloat16
return torch_dtype
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
if quant_type.startswith("int8") or quant_type.startswith("int4"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
# We need one of the supported dtypes to be selected in order for accelerate to determine
# the total size of modules/parameters for auto device placement.
possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
raise ValueError(
f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.9 for key, val in max_memory.items()}
return max_memory
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
param_device = kwargs.pop("param_device", None)
# Check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
return False
elif param_device == "cpu" and self.offload:
# We don't quantize weights that we offload
return False
else:
# We only quantize the weight of nn.Linear
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
):
r"""
Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
then we move it to the target device. Finally, we quantize the module.
"""
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
# If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
# about AffineQuantizedTensor
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
# As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
self.modules_to_not_convert.extend(keep_in_fp32_modules)
# Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
self.modules_to_not_convert.extend(keys_on_cpu)
# Purge `None`.
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "ModelMixin"):
return model
def is_serializable(self, safe_serialization=None):
# TODO(aryan): needs to be tested
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
)
return False
_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
"0.25.0"
)
if not _is_torchao_serializable:
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
if self.offload and self.quantization_config.modules_to_not_convert is None:
logger.warning(
"The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
)
return False
return _is_torchao_serializable
@property
def is_trainable(self):
return self.quantization_config.quant_type.startswith("int8")
......@@ -87,6 +87,7 @@ from .import_utils import (
is_torch_version,
is_torch_xla_available,
is_torch_xla_version,
is_torchao_available,
is_torchsde_available,
is_torchvision_available,
is_transformers_available,
......
......@@ -340,6 +340,15 @@ if _imageio_available:
_imageio_available = False
_is_torchao_available = importlib.util.find_spec("torchao") is not None
if _is_torchao_available:
try:
_torchao_version = importlib_metadata.version("torchao")
logger.debug(f"Successfully import torchao version {_torchao_version}")
except importlib_metadata.PackageNotFoundError:
_is_torchao_available = False
def is_torch_available():
return _torch_available
......@@ -460,6 +469,10 @@ def is_imageio_available():
return _imageio_available
def is_torchao_available():
return _is_torchao_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
......@@ -593,6 +606,11 @@ IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""
# docstyle-ignore
TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao`
"""
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
......@@ -618,6 +636,7 @@ BACKENDS_MAPPING = OrderedDict(
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
]
)
......
......@@ -39,6 +39,7 @@ from .import_utils import (
is_timm_available,
is_torch_available,
is_torch_version,
is_torchao_available,
is_torchsde_available,
is_transformers_available,
)
......@@ -476,6 +477,18 @@ def require_bitsandbytes_version_greater(bnb_version):
return decorator
def require_torchao_version_greater(torchao_version):
def decorator(test_case):
correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version
) > version.parse(torchao_version)
return unittest.skipUnless(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
......
The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/).
The benchmarks were run on a single H100. Below is `nvidia-smi`:
```bash
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 |
| N/A 34C P0 69W / 700W | 2MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
```
The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR.
The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
```
`diffusers-cli`:
```bash
- 🤗 Diffusers version: 0.32.0.dev0
- Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.6.0.dev20241112+cu121 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.26.2
- Transformers version: 4.46.3
- Accelerate version: 1.1.1
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
```
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from typing import List
import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_available,
is_torchao_available,
nightly,
require_torch,
require_torch_gpu,
require_torchao_version_greater,
slow,
torch_device,
)
enable_full_determinism()
if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Makes sure the config format is properly set
"""
quantization_config = TorchAoConfig("int4_weight_only")
torchao_orig_config = quantization_config.to_dict()
for key in torchao_orig_config:
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
def test_post_init_check(self):
"""
Test kwargs validations in TorchAoConfig
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "is not supported yet"):
_ = TorchAoConfig("uint8")
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
_ = TorchAoConfig("int4_weight_only", group_size1=32)
def test_repr(self):
"""
Check that there is no error in the repr
"""
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": [
"conv"
],
"quant_method": "torchao",
"quant_type": "int4_weight_only",
"quant_type_kwargs": {
"group_size": 8
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_components(self, quantization_config: TorchAoConfig):
model_id = "hf-internal-testing/tiny-flux-pipe"
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator().manual_seed(seed)
inputs = {
"prompt": "an astronaut riding a horse in space",
"height": 32,
"width": 32,
"num_inference_steps": 2,
"output_type": "np",
"generator": generator,
}
return inputs
def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device, dtype=torch.bfloat16)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
output_slice = output[-1, -1, -3:, -3:].flatten()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_quantization(self):
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint_a16w7"]:
# The dummy flux model that we use requires us to impose some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
self._test_quant_type(quantization_config, expected_slice)
def test_int4wo_quant_bfloat16_conversion(self):
"""
Tests whether the dtype of the model will be modified to bfloat16 for int4 weight-only quantization.
"""
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
self.assertEqual(weight.quant_min, 0)
self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
def test_offload(self):
"""
Test if the int4 weight-only quantized model is working properly with cpu/disk offload. Also verifies
that the device map is correctly set (in the `hf_device_map` attribute of the model).
"""
device_map_offload = {
"time_text_embed": torch_device,
"context_embedder": torch_device,
"x_embedder": torch_device,
"transformer_blocks.0": "cpu",
"single_transformer_blocks.0": "disk",
"norm_out": torch_device,
"proj_out": "cpu",
}
inputs = self.get_dummy_tensor_inputs(torch_device)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map_offload,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_offload)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
quantized_layer = quantized_model.proj_out
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
def test_training(self):
quantization_config = TorchAoConfig("int8_weight_only")
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for param in quantized_model.parameters():
# freeze the model as only adapter layers will be trained
param.requires_grad = False
if param.ndim == 1:
param.data = param.data.to(torch.float32)
for _, module in quantized_model.named_modules():
if isinstance(module, Attention):
module.to_q = LoRALayer(module.to_q, rank=4)
module.to_k = LoRALayer(module.to_k, rank=4)
module.to_v = LoRALayer(module.to_v, rank=4)
with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output.norm().backward()
for module in quantized_model.modules():
if isinstance(module, LoRALayer):
self.assertTrue(module.adapter[1].weight.grad is not None)
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
@nightly
def test_torch_compile(self):
r"""Test that verifies if torch.compile works with torchao quantization."""
quantization_config = TorchAoConfig("int8_weight_only")
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device, dtype=torch.bfloat16)
inputs = self.get_dummy_inputs(torch_device)
normal_output = pipe(**inputs)[0].flatten()[-32:]
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
inputs = self.get_dummy_inputs(torch_device)
compile_output = pipe(**inputs)[0].flatten()[-32:]
# Note: Seems to require higher tolerance
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
@staticmethod
def _get_memory_footprint(module):
quantized_param_memory = 0.0
unquantized_param_memory = 0.0
for param in module.parameters():
if param.__class__.__name__ == "AffineQuantizedTensor":
data, scale, zero_point = param.layout_tensor.get_plain()
quantized_param_memory += data.numel() * data.element_size()
quantized_param_memory += scale.numel() * scale.element_size()
quantized_param_memory += zero_point.numel() * zero_point.element_size()
else:
unquantized_param_memory += param.data.numel() * param.data.element_size()
total_memory = quantized_param_memory + unquantized_param_memory
return total_memory, quantized_param_memory, unquantized_param_memory
def test_memory_footprint(self):
r"""
A simple test to check if the model conversion has been done correctly by checking the
memory footprint of the converted model and the class type of its linear layers
"""
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
transformer_bf16 = self.get_dummy_components(None)["transformer"]
total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
transformer_int4wo_gs32
)
total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
# int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
# int4 with the default group size quantizes very few linear layers compared to a smaller group size of 32
self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
# int8 quantizes more layers compared to int4 with the default group size
self.assertTrue(quantized_int8wo < quantized_int4wo)
def test_wrong_config(self):
with self.assertRaises(ValueError):
self.get_dummy_components(TorchAoConfig("int42"))
# This class is not to be run as a test by itself. See the tests that follow this class
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
quant_method, quant_method_kwargs = None, None
device = "cuda"
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_model(self, device=None):
quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
quantized_model = FluxTransformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
return quantized_model.to(device)
def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
def test_original_model_expected_slice(self):
quantized_model = self.get_dummy_model(torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
def check_serialization_expected_slice(self, expected_slice):
quantized_model = self.get_dummy_model(self.device)
with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False
)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = loaded_quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(
isinstance(
loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
)
)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_serialization_expected_slice(self):
self.check_serialization_expected_slice(self.serialized_expected_slice)
class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
serialized_expected_slice = expected_slice
device = "cuda"
class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
serialized_expected_slice = expected_slice
device = "cuda"
class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
serialized_expected_slice = expected_slice
device = "cpu"
class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
serialized_expected_slice = expected_slice
device = "cpu"
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_components(self, quantization_config: TorchAoConfig):
model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator().manual_seed(seed)
inputs = {
"prompt": "an astronaut riding a horse in space",
"height": 512,
"width": 512,
"num_inference_steps": 20,
"output_type": "np",
"generator": generator,
}
return inputs
def _test_quant_type(self, quantization_config, expected_slice):
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0].flatten()
output_slice = np.concatenate((output[:16], output[-16:]))
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_quantization(self):
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
]
if TorchAoConfig._is_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()