"docs/vscode:/vscode.git/clone" did not exist on "a6752a7d3c23d03c5d4456c51e992250ebabfc1b"
Unverified Commit 47952192 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`bnb`] Fix bnb skip modules (#24043)

* fix skip modules test

* oops

* address comments
parent a1160185
......@@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
module._parameters[tensor_name] = new_value
def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
def _replace_with_bnb_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
......@@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_nam
has_been_replaced = True
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# Remove the last key for recursion
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
......
......@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.int8)
def test_llm_skip(self):
r"""
A simple test to check if `llm_int8_skip_modules` works as expected
"""
import bitsandbytes as bnb
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
"roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
self.assertTrue(
isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
self.assertTrue(seq_classification_model.classifier.out_proj != torch.int8)
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment