"...sharings/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "611ed639bb2ecd1db5a0241833c0d0a1e7812b1d"
Unverified commit 96a074fa authored by Marc Sun, committed by GitHub

Add new quant method (#32047)

* Add new quant method

* update

* fix multi-device

* add test

* add offload

* style

* style

* add simple example

* initial doc

* docstring

* style again

* works ?

* better docs

* switch to non persistant

* remove print

* fix init

* code review
parent bd9dca3b
@@ -157,6 +157,8 @@
    title: EETQ
  - local: quantization/hqq
    title: HQQ
  - local: quantization/fbgemm_fp8
    title: FBGEMM_FP8
  - local: quantization/optimum
    title: Optimum
  - local: quantization/contribute
......
@@ -56,3 +56,8 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## HqqConfig
[[autodoc]] HqqConfig
## FbgemmFp8Config
[[autodoc]] FbgemmFp8Config
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FBGEMM FP8
With the FBGEMM FP8 quantization method, you can quantize your model to FP8 (W8A8):
- the weights are quantized to 8-bit (FP8) per channel
- the activations are quantized to 8-bit (FP8) per token
It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library, which provides efficient low-precision general matrix multiplication for small batch sizes and supports accuracy-loss-minimizing techniques such as row-wise quantization and outlier-aware quantization.
> [!TIP]
> You need a GPU with compute capability >= 9.0 (e.g. an H100).
Before you begin, make sure the following libraries are installed with their latest version:
```bash
pip install --upgrade accelerate fbgemm-gpu torch
```
If you are having issues with the fbgemm-gpu and torch libraries, you might need to install the nightly release. You can follow the instructions [here](https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries:~:text=found%20here.-,Install%20the%20FBGEMM_GPU%20Package,-Install%20through%20PyTorch).
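To double-check the setup before loading a model, a quick sanity check (a sketch, assuming a CUDA GPU is visible) is to import the FBGEMM gen_ai ops and print the compute capability:
```py
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers the torch.ops.fbgemm.* FP8 kernels

# FBGEMM FP8 kernels require compute capability >= 9.0 (e.g. H100)
print(torch.cuda.get_device_capability())
```
With the environment in place, the following example quantizes a model on the fly and runs a short generation: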
```py
from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FbgemmFp8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
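To confirm the linear layers were converted, you can inspect one of them (a sketch; the exact module path depends on the architecture, a Llama-style model is assumed here): the weight is stored as `torch.float8_e4m3fn` with one scale per output channel.
```py
layer = quantized_model.model.layers[0].mlp.down_proj  # path assumes a Llama-style model
print(type(layer).__name__)      # FbgemmFp8Linear
print(layer.weight.dtype)        # torch.float8_e4m3fn
print(layer.weight_scale.shape)  # torch.Size([out_features, 1]), one scale per output channel
```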
A quantized model can be saved with `save_pretrained` and reloaded with `from_pretrained`.
```py
quant_path = "/path/to/save/quantized/model"
quantized_model.save_pretrained(quant_path)
quantized_model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```
@@ -55,4 +55,5 @@ Use the table below to help you decide which quantization method to use.
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
@@ -934,6 +934,7 @@ _import_structure = {
"AwqConfig",
"BitsAndBytesConfig",
"EetqConfig",
"FbgemmFp8Config",
"GPTQConfig", "GPTQConfig",
"HqqConfig", "HqqConfig",
"QuantoConfig", "QuantoConfig",
...@@ -5665,6 +5666,7 @@ if TYPE_CHECKING: ...@@ -5665,6 +5666,7 @@ if TYPE_CHECKING:
AwqConfig, AwqConfig,
BitsAndBytesConfig, BitsAndBytesConfig,
EetqConfig, EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HqqConfig,
QuantoConfig,
......
@@ -45,6 +45,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"ggml": [ "ggml": [
"GGUF_CONFIG_MAPPING", "GGUF_CONFIG_MAPPING",
"GGUF_TENSOR_MAPPING", "GGUF_TENSOR_MAPPING",
...@@ -126,6 +127,7 @@ if TYPE_CHECKING: ...@@ -126,6 +127,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config, unset_hf_deepspeed_config,
) )
from .eetq import replace_with_eetq_linear from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .ggml import (
GGUF_CONFIG_MAPPING,
GGUF_TENSOR_MAPPING,
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
if is_torch_available():
import torch
from torch import nn
if is_accelerate_available():
from accelerate import init_empty_weights
if is_fbgemm_gpu_available():
import fbgemm_gpu.experimental.gen_ai # noqa: F401
logger = logging.get_logger(__name__)
class FbgemmFp8Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
if bias:
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
else:
self.bias = None
def forward(self, x):
num_tokens = None
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
)
# moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now: we move the output back to the device of x
output = output.to(x.device)
del x_quantized, x_scale
return output
def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
if current_key_name is None:
current_key_name = []
for name, module in model.named_children():
current_key_name.append(name)
if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = FbgemmFp8Linear(
in_features,
out_features,
module.bias is not None,
)
has_been_replaced = True
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set the non-persistent buffer outside of init_empty_weights
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
This will enable running your models using high performance fp8 kernel from FBGEMM library.
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert to `FbgemmFp8Linear`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
`disk`).
"""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
)
if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
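For reference, here is a minimal sketch (not part of this diff) of how the replacement helper can be exercised directly on an empty model skeleton, mirroring the conversion test further below; `facebook/opt-350m` is used only as a small example architecture, and `accelerate` plus a recent `torch` are assumed to be installed:
```py
from accelerate import init_empty_weights
from transformers import AutoConfig, FbgemmFp8Config, OPTForCausalLM
from transformers.integrations import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear

# Build the model skeleton on the meta device so no real weights are allocated
config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

# Swap every nn.Linear (except lm_head by default) for FbgemmFp8Linear
model = replace_with_fbgemm_fp8_linear(model, quantization_config=FbgemmFp8Config())
print(any(isinstance(m, FbgemmFp8Linear) for m in model.modules()))  # True
```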
@@ -868,7 +868,7 @@ def _load_state_dict_into_meta_model(
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
@@ -894,7 +894,6 @@ def _load_state_dict_into_meta_model(
old_param = getattr(old_param, split)
if old_param is None:
break
if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
@@ -3955,6 +3954,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if (
hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
device_map_kwargs["offload_buffers"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
@@ -4105,7 +4112,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
......
@@ -20,6 +20,7 @@ from ..utils.quantization_config import (
AwqConfig,
BitsAndBytesConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HqqConfig,
QuantizationConfigMixin,
@@ -31,6 +32,7 @@ from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
@@ -45,6 +47,7 @@ AUTO_QUANTIZER_MAPPING = {
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -56,6 +59,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
"fbgemm_fp8": FbgemmFp8Config,
}
@@ -156,8 +160,11 @@ class AutoHfQuantizer:
if isinstance(quantization_config, dict):
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
if (
isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
and quantization_config_from_args is not None
):
# special case for GPTQ / AWQ / FbgemmFp8 config collision
loading_attr_dict = quantization_config_from_args.get_loading_attributes()
for attr, val in loading_attr_dict.items():
setattr(quantization_config, attr, val)
......
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from packaging import version
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
from .quantizers_utils import get_module_from_name
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class FbgemmFp8HfQuantizer(HfQuantizer):
"""
FP8 quantization using fbgemm kernels
"""
requires_parameters_quantization = True
requires_calibration = False
required_packages = ["fbgemm-gpu", "accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
raise ImportError(
"Using fbgemm fp8 quantization requires torch > 2.1.0"
"Please install the latest version of torch ( pip install --upgrade torch )"
)
if not is_fbgemm_gpu_available():
raise ImportError(
"Using fbgemm fp8 quantization requires fbgemm-gpu library"
"Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
)
if not is_accelerate_available("0.32.2"):
raise ImportError(
"Loading an FP8 quantized model requires accelerate > 0.32.1 (`pip install --upgrade accelerate`)"
)
if not torch.cuda.is_available():
raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
if major < 9:
raise ValueError(
"FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
)
device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
)
elif device_map is not None:
if (
not self.pre_quantized
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to load an FP8 model with a device_map that contains a CPU or disk device."
"This is not supported when the model is quantized on the fly. "
"Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = torch.bfloat16
logger.info(
"Overriding torch_dtype=%s with `torch_dtype=torch.bloat16` due to "
"requirements of `fbgemm-gpu` to enable model loading in fp8. "
"Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
" torch_dtype=torch.bfloat16 to remove this warning.",
torch_dtype,
)
elif torch_dtype == torch.float16:
raise ValueError(
"You cannot use FP8 with torch_dtype=torch.float16."
"We recommend you passing torch_dtype=torch.bfloat16"
)
return torch_dtype
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
from ..integrations import FbgemmFp8Linear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FbgemmFp8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
):
"""
Quantizes weights into weight and weight_scale
"""
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
module, tensor_name = get_module_from_name(model, param_name)
module._buffers[tensor_name] = new_value.to(target_device)
# to have the right output shape -> (out_features, 1)
module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
del param_name
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear
self.modules_to_not_convert = get_keys_to_not_convert(model)
if self.quantization_config.modules_to_not_convert is not None:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
model = replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
)
model.config.quantization_config = self.quantization_config
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import FbgemmFp8Linear
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FbgemmFp8Linear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
@property
def is_serializable(self):
return True
@property
def is_trainable(self) -> bool:
return False
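As a side note (a sketch, not part of this diff), the core primitive used by `create_quantized_param` above is `torch.ops.fbgemm.quantize_fp8_per_row`, which returns the FP8 tensor together with one scale per row; it assumes `fbgemm-gpu` is installed and a compute-capability >= 9.0 GPU is available:
```py
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers the torch.ops.fbgemm.* kernels

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
q_weight, scale = torch.ops.fbgemm.quantize_fp8_per_row(weight)
print(q_weight.dtype)  # torch.float8_e4m3fn
print(scale.shape)     # one scale per row; stored as (out_features, 1) by the quantizer
```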
@@ -68,6 +68,7 @@ from .utils import (
is_eetq_available,
is_essentia_available,
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
is_flax_available,
is_fsdp_available,
@@ -1116,6 +1117,13 @@ def require_quanto(test_case):
return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)
def require_fbgemm_gpu(test_case):
"""
Decorator for fbgemm_gpu dependency
"""
return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
......
@@ -127,6 +127,7 @@ from .import_utils import (
is_eetq_available,
is_essentia_available,
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
......
@@ -98,6 +98,7 @@ _aqlm_available = _is_package_available("aqlm")
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@@ -888,6 +889,10 @@ def is_eetq_available():
return _eetq_available
def is_fbgemm_gpu_available():
return _fbgemm_gpu_available
def is_levenshtein_available():
return _levenshtein_available
......
@@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum):
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
FBGEMM_FP8 = "fbgemm_fp8"
class AWQLinearVersion(str, Enum):
@@ -1047,3 +1048,34 @@ class EetqConfig(QuantizationConfigMixin):
accepted_weights = ["int8"]
if self.weights not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
@dataclass
class FbgemmFp8Config(QuantizationConfigMixin):
"""
This is a wrapper class for all the attributes and features that you can play with when loading a model using
fbgemm fp8 quantization.
Args:
activation_scale_ub (`float`, *optional*, defaults to 1200.0):
The activation scale upper bound. This is used when quantizing the input activation.
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require some modules to be
left in their original precision.
"""
def __init__(
self,
activation_scale_ub: float = 1200.0,
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.FBGEMM_FP8
self.activation_scale_ub = activation_scale_ub
self.modules_to_not_convert = modules_to_not_convert
def get_loading_attributes(self):
attributes_dict = copy.deepcopy(self.__dict__)
loading_attributes = ["activation_scale_ub"]
loading_attributes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes}
return loading_attributes_dict
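As a quick illustration (a sketch; the values are arbitrary), the loading-relevant attribute and the skip list can be set directly when constructing the config:
```py
from transformers import FbgemmFp8Config

# Lower the activation clamp and keep lm_head in its original precision
quantization_config = FbgemmFp8Config(
    activation_scale_ub=1000.0,
    modules_to_not_convert=["lm_head"],
)
print(quantization_config.get_loading_attributes())  # {'activation_scale_ub': 1000.0}
```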
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM
from transformers.testing_utils import (
require_accelerate,
require_fbgemm_gpu,
require_read_token,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class FbgemmFp8ConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
"""
quantization_config = FbgemmFp8Config()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
"""
dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fbgemm_fp8"}
quantization_config = FbgemmFp8Config.from_dict(dict)
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
self.assertEqual(dict["quant_method"], quantization_config.quant_method)
@slow
@require_torch_gpu
@require_fbgemm_gpu
@require_accelerate
@require_read_token
class FbgemmFp8Test(unittest.TestCase):
model_name = "meta-llama/Meta-Llama-3-8B"
input_text = "What are we having for dinner?"
max_new_tokens = 9
EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad"
device_map = "cuda"
offload_device_map = {
"model.embed_tokens": 0,
"model.layers.0": 0,
"model.layers.1": 0,
"model.layers.2": 0,
"model.layers.3": 0,
"model.layers.4": 0,
"model.layers.5": 0,
"model.layers.6": 0,
"model.layers.7": 0,
"model.layers.8": 0,
"model.layers.9": 0,
"model.layers.10": 0,
"model.layers.11": 0,
"model.layers.12": 0,
"model.layers.13": 0,
"model.layers.14": 0,
"model.layers.15": 0,
"model.layers.16": "cpu",
"model.layers.17": "cpu",
"model.layers.18": "cpu",
"model.layers.19": "cpu",
"model.layers.20": "disk",
"model.layers.21": "disk",
"model.layers.22": "disk",
"model.layers.23": "disk",
"model.layers.24": "disk",
"model.layers.25": "disk",
"model.layers.26": "disk",
"model.layers.27": "disk",
"model.layers.28": "disk",
"model.layers.29": "disk",
"model.layers.30": "disk",
"model.layers.31": "disk",
"model.norm": "disk",
"lm_head": "disk",
}
# called only once for all tests in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
quantization_config = FbgemmFp8Config()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from transformers.integrations import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = FbgemmFp8Config()
with init_empty_weights():
model = OPTForCausalLM(config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config)
nb_fbgemm_linear = 0
for module in model.modules():
if isinstance(module, FbgemmFp8Linear):
nb_fbgemm_linear += 1
self.assertEqual(nb_linears - 1, nb_fbgemm_linear)
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = FbgemmFp8Config(modules_to_not_convert=["fc1"])
model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config)
nb_fbgemm_linear = 0
for module in model.modules():
if isinstance(module, FbgemmFp8Linear):
nb_fbgemm_linear += 1
self.assertEqual(nb_linears - 25, nb_fbgemm_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_change_loading_attributes(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
quantization_config = FbgemmFp8Config(activation_scale_ub=1000.0)
model = AutoModelForCausalLM.from_pretrained(
tmpdirname, device_map=self.device_map, quantization_config=quantization_config
)
self.assertEqual(model.model.layers[1].mlp.down_proj.input_scale_ub.item(), 1000.0)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = FbgemmFp8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config
)
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_offload(self):
"""
Simple test that checks that an error is raised when quantizing a model on the fly with a cpu/disk offload device_map
"""
quantization_config = FbgemmFp8Config()
with self.assertRaisesRegex(
ValueError, "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device."
):
AutoModelForCausalLM.from_pretrained(
self.model_name, device_map=self.offload_device_map, quantization_config=quantization_config
)
def test_save_pretrained_offload(self):
"""
Simple test that checks if the saved quantized model is working properly with cpu/disk offload
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_save_pretrained_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
self.assertTrue(set(model.hf_device_map.values()) == {0, 1})
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)