"tests/perceiver/test_tokenization_perceiver.py" did not exist on "fecb08c2b8d8003f19147c8c6f5ad6d4df23c710"
Unverified Commit b4c18a83 authored by zhong zhuang, committed by GitHub

[FEAT]: EETQ quantizer support (#30262)



* [FEAT]: EETQ quantizer support

* Update quantization.md

* Update docs/source/en/main_classes/quantization.md
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/quantization.md
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/quantization.md
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/__init__.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/__init__.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/eetq_integration/test_eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/auto.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/auto.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/auto.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/eetq_integration/test_eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/eetq_integration/test_eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/eetq_integration/test_eetq.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* [FEAT]: EETQ quantizer support

* [FEAT]: EETQ quantizer support

* remove whitespaces

* update quantization.md

* style

* Update docs/source/en/quantization.md
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add copyright

* Update quantization.md

* Update docs/source/en/quantization.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/quantization.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Address the comments by amyeroberts

* style

---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 569743f5
......@@ -52,6 +52,9 @@ RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoA
# Add quanto for quantization testing
RUN python3 -m pip install --no-cache-dir quanto
# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git
# When installing in editable mode, `transformers` is not recognized as a package.
# This line must be added in order for Python to be aware of transformers.
RUN cd transformers && python3 setup.py develop
\ No newline at end of file
......@@ -38,6 +38,9 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] AwqConfig
## EetqConfig
[[autodoc]] EetqConfig
## GPTQConfig
[[autodoc]] GPTQConfig
......
......@@ -642,6 +642,37 @@ double_quant_config = BitsAndBytesConfig(
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
```
## EETQ
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUs. The high-performance GEMM and GEMV kernels come from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and the model does not need to be pre-quantized. Moreover, accuracy degradation is negligible owing to the per-channel quantization.
Make sure you have eetq installed from the [release page](https://github.com/NetEase-FuXi/EETQ/releases):
```bash
pip install --no-cache-dir https://github.com/NetEase-FuXi/EETQ/releases/download/v1.0.0/EETQ-1.0.0+cu121+torch2.1.2-cp310-cp310-linux_x86_64.whl
```
or install it from source at https://github.com/NetEase-FuXi/EETQ. EETQ requires a GPU with CUDA compute capability between 7.0 and 8.9:
```bash
git clone https://github.com/NetEase-FuXi/EETQ.git
cd EETQ/
git submodule update --init --recursive
pip install .
```
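You can quickly confirm the install is visible to Python (a minimal sanity check, not part of the official instructions):
```bash
python -c "import eetq; print(eetq.__name__)"
```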
An unquantized model can be quantized via `from_pretrained`:
```py
from transformers import AutoModelForCausalLM, EetqConfig
path = "/path/to/model"
quantization_config = EetqConfig("int8")
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", quantization_config=quantization_config)
```
A quantized model can be saved via `save_pretrained` and reused via `from_pretrained`:
```py
quant_path = "/path/to/save/quantized/model"
model.save_pretrained(quant_path)
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```
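To sanity-check the reloaded model, you can run a short generation. This is a minimal sketch; it assumes the tokenizer is available at the original `path` and the prompt is purely illustrative:
```py
from transformers import AutoTokenizer

# Illustrative smoke test for the quantized model loaded above.
tokenizer = AutoTokenizer.from_pretrained(path)
inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```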
## Optimum
The [Optimum](https://huggingface.co/docs/optimum/index) library supports quantization for Intel, Furiosa, ONNX Runtime, GPTQ, and lower-level PyTorch quantization functions. Consider using Optimum for quantization if you're using specific and optimized hardware like Intel CPUs, Furiosa NPUs or a model accelerator like ONNX Runtime.
......
......@@ -1126,7 +1126,14 @@ _import_structure = {
"is_vision_available",
"logging",
],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig", "QuantoConfig"],
"utils.quantization_config": [
"AqlmConfig",
"AwqConfig",
"BitsAndBytesConfig",
"EetqConfig",
"GPTQConfig",
"QuantoConfig",
],
}
# sentencepiece-backed objects
......@@ -6071,7 +6078,14 @@ if TYPE_CHECKING:
)
# bitsandbytes config
from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantoConfig
from .utils.quantization_config import (
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
QuantoConfig,
)
try:
if not is_sentencepiece_available():
......
......@@ -42,6 +42,7 @@ _import_structure = {
"set_hf_deepspeed_config",
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
......@@ -111,6 +112,7 @@ if TYPE_CHECKING:
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,
......
# coding=utf-8
# Copyright 2024 NetEase, Inc. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_accelerate_available, is_eetq_available, logging
if is_eetq_available():
import eetq
import torch.nn as nn
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
def _replace_with_eetq_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
if current_key_name is None:
current_key_name = []
for name, module in model.named_children():
current_key_name.append(name)
if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
model._modules[name] = eetq.EetqLinear(
in_features, out_features, module.bias is not None, module.weight.device
)
if pre_quantized:
model._modules[name].register_scale(module.weight.device)
has_been_replaced = True
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_eetq_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_eetq_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
"""
A helper function to replace all `torch.nn.Linear` modules by `eetq.EetqLinear` modules from the `eetq`
library. This enables running your models with the high-performance int8 weight-only GEMM kernels from
FasterTransformer and TensorRT-LLM. Make sure `eetq` compiled against the CUDA version of your
hardware is installed before running this function. EETQ should be installed from source:
'https://github.com/NetEase-FuXi/EETQ'
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert into `EetqLinear`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
`disk`).
"""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_eetq_linear(
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
)
if not has_been_replaced:
logger.warning(
"You are loading your model using eetq but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
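For reference, the helper can be exercised directly on a meta-initialized model, much like the integration test further below; the model id is illustrative and `accelerate`/`eetq` are assumed to be installed:
```py
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, EetqConfig
from transformers.integrations import replace_with_eetq_linear

# Build the model skeleton on the meta device so no real memory is allocated.
config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Every nn.Linear except lm_head (and any modules_to_not_convert) becomes eetq.EetqLinear.
model = replace_with_eetq_linear(model, quantization_config=EetqConfig(weights="int8"))
```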
......@@ -19,6 +19,7 @@ from ..utils.quantization_config import (
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
QuantizationConfigMixin,
QuantizationMethod,
......@@ -28,6 +29,7 @@ from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
......@@ -39,12 +41,14 @@ AUTO_QUANTIZER_MAPPING = {
"gptq": GptqHfQuantizer,
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
"awq": AwqConfig,
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"eetq": EetqConfig,
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
......
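These two mappings are what let a serialized config with `quant_method="eetq"` resolve to the new classes. A minimal sketch of the lookup (direct dictionary access is shown only for illustration; real loading goes through the auto quantizer machinery):
```py
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["eetq"]  # EetqConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING["eetq"]          # EetqHfQuantizer

# Instantiate the quantizer from a freshly built config (illustration only).
quantization_config = config_cls(weights="int8")
quantizer = quantizer_cls(quantization_config)
```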
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_eetq_available, is_torch_available, logging
from .quantizers_utils import get_module_from_name
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class EetqHfQuantizer(HfQuantizer):
"""
8-bit weight-only quantization via the EETQ method:
before loading: replaces `torch.nn.Linear` modules with `eetq.EetqLinear` modules
during loading: quantizes each 16-bit linear weight to int8 per channel as it is assigned to its module
"""
requires_parameters_quantization = True
requires_calibration = False
required_packages = ["eetq", "accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_eetq_available():
raise ImportError(
"Using `eetq` 8-bit quantization requires eetq."
"Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ"
)
if not is_accelerate_available():
raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)")
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
raise ValueError(
"Converting into 8-bit weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
"You have loaded an EETQ model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
elif device_map is not None:
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an EETQ model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = torch.float16
logger.info(
"Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
"requirements of `eetq` to enable model loading in 8-bit. "
"Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
" torch_dtype=torch.float16 to remove this warning.",
torch_dtype,
)
elif torch_dtype != torch.float16:
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with EETQ.")
return torch_dtype
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
from eetq import EetqLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, EetqLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
):
"""
quantizes weights into qweight and weight_scales
"""
from eetq import quantize_and_preprocess_weights
module, tensor_name = get_module_from_name(model, param_name)
new_value, weight_scale = quantize_and_preprocess_weights(param_value)
module._buffers[tensor_name] = new_value.to(target_device)
module.register("weight_scales", weight_scale.to(target_device))
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear
self.modules_to_not_convert = get_keys_to_not_convert(model)
if self.quantization_config.modules_to_not_convert is not None:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
model = replace_with_eetq_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
)
model.config.quantization_config = self.quantization_config
@property
def is_serializable(self):
return True
@property
def is_trainable(self) -> bool:
return False
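A standalone sketch of the dtype defaulting implemented in `update_torch_dtype` above (assumes `torch` is installed; nothing here touches a GPU):
```py
import torch

from transformers import EetqConfig
from transformers.quantizers.quantizer_eetq import EetqHfQuantizer

quantizer = EetqHfQuantizer(EetqConfig())
# With no dtype provided, the quantizer falls back to float16 and logs a note.
assert quantizer.update_torch_dtype(None) == torch.float16
# An explicit dtype is kept as-is; a hint about float16 efficiency is logged.
assert quantizer.update_torch_dtype(torch.bfloat16) == torch.bfloat16
```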
......@@ -65,6 +65,7 @@ from .utils import (
is_cython_available,
is_decord_available,
is_detectron2_available,
is_eetq_available,
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
......@@ -1014,6 +1015,13 @@ def require_aqlm(test_case):
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
def require_eetq(test_case):
"""
Decorator marking a test that requires eetq
"""
return unittest.skipUnless(is_eetq_available(), "test requires eetq")(test_case)
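It is used like the other `require_*` decorators; a hypothetical test showing the pattern (class name and body are illustrative only):
```py
import unittest

from transformers.testing_utils import require_eetq


@require_eetq
class EetqSmokeTest(unittest.TestCase):
    def test_linear_class_exists(self):
        # Skipped automatically when eetq is not installed.
        from eetq import EetqLinear

        self.assertTrue(callable(EetqLinear))
```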
def require_av(test_case):
"""
Decorator marking a test that requires av
......
......@@ -119,6 +119,7 @@ from .import_utils import (
is_datasets_available,
is_decord_available,
is_detectron2_available,
is_eetq_available,
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
......
......@@ -97,6 +97,7 @@ _apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
......@@ -829,6 +830,10 @@ def is_auto_gptq_available():
return _auto_gptq_available
def is_eetq_available():
return _eetq_available
def is_levenshtein_available():
return _levenshtein_available
......
......@@ -40,6 +40,7 @@ class QuantizationMethod(str, Enum):
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
class AWQLinearVersion(str, Enum):
......@@ -893,3 +894,37 @@ class QuantoConfig(QuantizationConfigMixin):
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
if self.activations not in accepted_activations:
raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}")
@dataclass
class EetqConfig(QuantizationConfigMixin):
"""
This is a wrapper class for all the attributes and features that you can play with for a model that has been
loaded using `eetq`.
Args:
weights (`str`, *optional*, defaults to `"int8"`):
The target dtype for the weights. The only supported value is "int8".
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision.
"""
def __init__(
self,
weights: str = "int8",
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.EETQ
self.weights = weights
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct
"""
accepted_weights = ["int8"]
if self.weights not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, EetqConfig, OPTForCausalLM
from transformers.testing_utils import (
require_accelerate,
require_eetq,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class EetqConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
"""
quantization_config = EetqConfig()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
"""
dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "eetq", "weights": "int8"}
quantization_config = EetqConfig.from_dict(dict)
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
self.assertEqual(dict["quant_method"], quantization_config.quant_method)
self.assertEqual(dict["weights"], quantization_config.weights)
@slow
@require_torch_gpu
@require_eetq
@require_accelerate
class EetqTest(unittest.TestCase):
model_name = "facebook/opt-350m"
input_text = "What are we having for dinner?"
max_new_tokens = 9
EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad"
device_map = "cuda"
# called only once for all test in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
quantization_config = EetqConfig(weights="int8")
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from eetq import EetqLinear
from transformers.integrations import replace_with_eetq_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = EetqConfig(weights="int8")
with init_empty_weights():
model = OPTForCausalLM(config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model = replace_with_eetq_linear(model, quantization_config=quantization_config)
nb_eetq_linear = 0
for module in model.modules():
if isinstance(module, EetqLinear):
nb_eetq_linear += 1
self.assertEqual(nb_linears - 1, nb_eetq_linear)
# Try with `linear_weights_not_to_quantize`
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = EetqConfig(modules_to_not_convert=["fc1"])
model = replace_with_eetq_linear(model, quantization_config=quantization_config)
nb_eetq_linear = 0
for module in model.modules():
if isinstance(module, EetqLinear):
nb_eetq_linear += 1
self.assertEqual(nb_linears - 25, nb_eetq_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = EetqConfig()
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config
)
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
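The integration tests are marked `@slow`, so they only run when explicitly enabled; a typical invocation (assumes a CUDA GPU with eetq and accelerate installed):
```bash
RUN_SLOW=1 python -m pytest -v tests/quantization/eetq_integration/test_eetq.py
```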