Unverified Commit 3f435823 authored by Younes Belkada, committed by GitHub

FEAT / Bitsandbytes: Add `dequantize` API for bitsandbytes quantized models (#30806)



* add method

* change method name

* more comments

* Apply suggestions from code review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fixup

* add docstrings and fix comment

* warn users on the de-quantized dtype

* Update src/transformers/quantizers/base.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/integrations/bitsandbytes.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final suggestion - use private method

---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 58faa7b8
......@@ -642,6 +642,27 @@ double_quant_config = BitsAndBytesConfig(
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
```
### Dequantizing `bitsandbytes` models
Once quantized, you can dequantize a model back to its original precision. Note that this might result in a small quality loss. Also make sure you have enough GPU RAM to fit the dequantized model.
Below is how to dequantize a 4-bit model quantized with `bitsandbytes`.
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True))
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.dequantize()
text = tokenizer("Hello my name is", return_tensors="pt").to(0)
out = model.generate(**text)
print(tokenizer.decode(out[0]))
```
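Dequantization works the same way for 8-bit models. Below is a minimal sketch (reusing `facebook/opt-125m` purely for illustration):
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "facebook/opt-125m"

# Load the model in 8-bit ...
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)

# ... then restore plain torch.nn.Linear layers in place of the bnb 8-bit linear layers
model_8bit.dequantize()
```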
## EETQ
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUs. The high-performance GEMM and GEMV kernels come from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and the model does not need to be pre-quantized. Moreover, accuracy degradation is negligible thanks to the per-channel quantization.
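As a rough sketch of what loading with EETQ looks like (the `EetqConfig` usage and checkpoint name below are illustrative assumptions, not part of this change):
```python
from transformers import AutoModelForCausalLM, EetqConfig

# Weight-only int8 quantization applied on the fly at load time; no calibration data required
quantization_config = EetqConfig("int8")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=quantization_config, device_map="auto"
)
```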
......
......@@ -25,6 +25,7 @@ _import_structure = {
"replace_with_awq_linear",
],
"bitsandbytes": [
"dequantize_and_replace",
"get_keys_to_not_convert",
"replace_8bit_linear",
"replace_with_bnb_linear",
......@@ -105,6 +106,7 @@ if TYPE_CHECKING:
replace_with_awq_linear,
)
from .bitsandbytes import (
dequantize_and_replace,
get_keys_to_not_convert,
replace_8bit_linear,
replace_with_bnb_linear,
......
import importlib.metadata
import inspect
import warnings
from copy import deepcopy
from inspect import signature
......@@ -16,7 +17,9 @@ if is_bitsandbytes_available():
from ..pytorch_utils import Conv1D
if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import find_tied_parameters
logger = logging.get_logger(__name__)
......@@ -322,3 +325,141 @@ def get_keys_to_not_convert(model):
filtered_module_names.append(name)
return filtered_module_names
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
"""
Helper function to dequantize 4bit or 8bit bnb weights.
If the weight is not a bnb quantized weight, it will be returned as is.
"""
if not isinstance(weight, torch.nn.Parameter):
raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
cls_name = weight.__class__.__name__
if cls_name not in ("Params4bit", "Int8Params"):
return weight
if cls_name == "Params4bit":
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
logger.warning_once(
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
)
return output_tensor
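# 8-bit path (Int8Params): recover the fp16 weight by multiplying the quantized weight by an
# identity matrix with the int8 GEMM (igemmlt) and dequantizing the product; the trailing .t()
# restores the original (out_features, in_features) layout.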
if state.SCB is None:
state.SCB = weight.SCB
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
im, Sim = bnb.functional.transform(im, "col32")
if state.CxB is None:
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
def _create_accelerate_new_hook(old_hook):
r"""
Creates a new hook based on the old hook. Use it only if you know what you are doing!
This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
with some changes
"""
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
old_hook_attr = old_hook.__dict__
filtered_old_hook_attr = {}
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
for k in old_hook_attr.keys():
if k in old_hook_init_signature.parameters:
filtered_old_hook_attr[k] = old_hook_attr[k]
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook
def _dequantize_and_replace(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
"""
Converts a quantized model into its dequantized original version. The newly converted model will have
some performance drop compared to the original model before quantization - use it only for specific use cases
such as merging QLoRA adapters.
Returns the converted model and a boolean that indicates whether the conversion has been successful or not.
"""
quant_method = quantization_config.quantization_method()
target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, target_cls) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
bias = getattr(module, "bias", None)
device = module.weight.device
with init_empty_weights():
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
if quant_method == "llm_int8":
state = module.state
else:
state = None
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
if bias is not None:
new_module.bias = bias
# Create a new hook and attach it in case we use accelerate
if hasattr(module, "_hf_hook"):
old_hook = module._hf_hook
new_hook = _create_accelerate_new_hook(old_hook)
remove_hook_from_module(module)
add_hook_to_module(new_module, new_hook)
new_module.to(device)
model._modules[name] = new_module
if len(list(module.children())) > 0:
_, has_been_replaced = _dequantize_and_replace(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def dequantize_and_replace(
model,
modules_to_not_convert=None,
quantization_config=None,
):
model, has_been_replaced = _dequantize_and_replace(
model,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
if not has_been_replaced:
logger.warning(
"For some reason the model has not been properly dequantized. You might see unexpected behavior."
)
return model
......@@ -1327,6 +1327,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
def dequantize(self):
"""
Potentially dequantize the model in case it has been quantized by a quantization method that supports
dequantization.
"""
hf_quantizer = getattr(self, "hf_quantizer", None)
if hf_quantizer is None:
raise ValueError("You need to first quantize your model in order to dequantize it")
return hf_quantizer.dequantize(self)
def _backward_compatibility_gradient_checkpointing(self):
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
......
......@@ -194,6 +194,23 @@ class HfQuantizer(ABC):
"""
return self._process_model_after_weight_loading(model, **kwargs)
def dequantize(self, model):
"""
Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
Note that not all quantization schemes support this.
"""
model = self._dequantize(model)
# Delete the quantizer so the model is no longer treated as quantized
del model.hf_quantizer
return model
def _dequantize(self, model):
raise NotImplementedError(
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
)
@abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs):
...
......
......@@ -312,3 +312,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
@property
def is_trainable(self) -> bool:
return True
def _dequantize(self, model):
from ..integrations import dequantize_and_replace
model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
)
return model
......@@ -281,3 +281,11 @@ class Bnb8BitHfQuantizer(HfQuantizer):
@property
def is_trainable(self) -> bool:
return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
def _dequantize(self, model):
from ..integrations import dequantize_and_replace
model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
)
return model
......@@ -239,6 +239,23 @@ class Bnb4BitTest(Base4bitTest):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality_dequantize(self):
r"""
Test that loading the model and dequantizing it produces correct results
"""
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
)
model_4bit.dequantize()
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assign a device to) a model after converting it to 8-bit throws an error.
......
......@@ -285,6 +285,23 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality_dequantize(self):
r"""
Test that loading the model and dequantizing it produces correct results
"""
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
)
model_8bit.dequantize()
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_raise_if_config_and_load_in_8bit(self):
r"""
Test that loading the model with the config and `load_in_8bit` raises an error
......