Unverified commit 47952192, authored by Younes Belkada, committed by GitHub

[`bnb`] Fix bnb skip modules (#24043)

* fix skip modules test

* oops

* address comments
parent a1160185
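
What the fix does: `_replace_with_bnb_linear` now appends each child module's name to `current_key_name` before recursing and pops it afterwards, so entries in `llm_int8_skip_modules` are matched against the full dotted path (e.g. `classifier.dense`) rather than a bare leaf name, and `has_been_replaced` is threaded through the recursion as a parameter instead of being reset to `False` on every recursive call. A regression test, `test_llm_skip`, covers the option end to end. A rough usage sketch mirroring the new test (requires `bitsandbytes` and a CUDA-capable setup; checkpoint name as in the test):

```python
import torch
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Quantize everything to int8 except modules whose path matches "classifier".
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)

assert model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8
assert model.classifier.dense.weight.dtype != torch.int8  # the head was skipped
```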
@@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None):
         module._parameters[tensor_name] = new_value
 
 
-def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+def _replace_with_bnb_linear(
+    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
+):
     """
     Private method that wraps the recursion for module replacement.
 
     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
     """
-    has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
+        current_key_name.append(name)
 
         if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
@@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
                     has_been_replaced = True
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_bnb_linear(
                 module,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
+                has_been_replaced=has_been_replaced,
             )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
     return model, has_been_replaced
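To see why the `append`/`pop` pair around the recursion matters, here is a self-contained illustrative sketch (not the transformers helper itself) of the same traversal pattern, collecting the dotted path of every `nn.Linear` the way the fixed code builds keys for skip matching:

```python
import torch.nn as nn

def collect_linear_paths(model, current_key_name=None, found=None):
    """Walk named_children() recursively, tracking each module's dotted path."""
    if found is None:
        found = []
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        if isinstance(module, nn.Linear):
            # The dotted path, e.g. "classifier.dense", is what
            # `llm_int8_skip_modules` entries are matched against.
            found.append(".".join(current_key_name))
        if len(list(module.children())) > 0:
            collect_linear_paths(module, current_key_name, found)
        # Pop the last key so siblings do not inherit this branch's name.
        current_key_name.pop(-1)
    return found

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
print(collect_linear_paths(model))  # ['0', '1.0']
```

Without the `pop`, names would accumulate across siblings and the paths would be wrong; without the `append`, nested modules would be matched by leaf name only, which is the bug this commit fixes.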
@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
             if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
                 self.assertTrue(module.weight.dtype == torch.int8)
 
+    def test_llm_skip(self):
+        r"""
+        A simple test to check if `llm_int8_skip_modules` works as expected
+        """
+        import bitsandbytes as bnb
+
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
+        seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
+            "roberta-large-mnli", quantization_config=quantization_config
+        )
+        self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
+        self.assertTrue(
+            isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
+        )
+
+        self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
+        self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.out_proj.weight.dtype != torch.int8)
+
     def test_generate_quality(self):
         r"""
         Test the generation quality of the quantized model and see that we are matching the expected output.
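For a quick audit after loading, continuing from the quantized `model` in the first snippet above (assumes `bitsandbytes` is importable), one can list exactly which modules the replacement pass converted; skipped modules should not appear:

```python
import bitsandbytes as bnb

# List every module the replacement pass converted to 8-bit;
# anything under "classifier" should be absent.
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear8bitLt):
        print("quantized:", name)
```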