"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e174bfeb340d3d3468d9c8eebce95c42aa2dcf84"
Unverified Commit 12240925 authored by Dario Sučić, committed by GitHub

Add bitsandbytes support for gpt2 models (#24504)



* Add bitsandbytes support for gpt2 models

* Guard Conv1D import to pass tensorflow test

* Appease ruff linter

* Fix 4bit test and remove int8 test boilerplate

* Update tests/bnb/test_mixed_int8.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 89b6ee49
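For context, a minimal usage sketch (not part of this commit) of what the change enables: quantized loading of GPT-2 checkpoints, whose linear-like layers are Conv1D modules. The checkpoint name and generation settings below are illustrative only.

# Sketch only: quantized loading of a GPT-2 checkpoint via the existing bnb paths.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "gpt2-xl"  # illustrative; matches the new test classes below
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 8-bit (LLM.int8()) loading
model_8bit = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# 4-bit (NF4) loading via BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model_4bit = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model_4bit.device)
print(tokenizer.decode(model_4bit.generate(**inputs, max_new_tokens=10)[0]))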
@@ -12,6 +12,8 @@ if is_bitsandbytes_available():
     import torch
     import torch.nn as nn
 
+    from ..pytorch_utils import Conv1D
+
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.utils import find_tied_parameters
@@ -84,6 +86,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None
         else:
             new_value = torch.tensor(value, device="cpu")
 
+        # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
+        # Since weights are saved in the correct "orientation", we skip transposing when loading.
+        if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
+            new_value = new_value.T
+
         kwargs = old_value.__dict__
         if is_8bit:
             new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
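The transpose above exists because transformers' Conv1D stores its weight as (in_features, out_features), the opposite of nn.Linear's (out_features, in_features), while the bnb parameter classes expect the nn.Linear orientation. A small illustrative sketch of the equivalence (not part of the patch):

# Illustration: Conv1D(x) == x @ W_conv + b, while Linear(x) == x @ W_lin.T + b,
# so an equivalent nn.Linear weight is W_conv.T. This is why the weight is
# transposed before being wrapped in Int8Params/Params4bit.
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=8, nx=4)   # weight shape (nx, nf) == (in_features, out_features)
lin = nn.Linear(4, 8)       # weight shape (out_features, in_features)
lin.weight.data = conv.weight.data.T.contiguous()
lin.bias.data = conv.bias.data.clone()

x = torch.randn(2, 4)
assert torch.allclose(conv(x), lin(x), atol=1e-6)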
@@ -122,14 +129,20 @@ def _replace_with_bnb_linear(
             current_key_name = []
         current_key_name.append(name)
 
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+        if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
+                    if isinstance(module, Conv1D):
+                        in_features, out_features = module.weight.shape
+                    else:
+                        in_features = module.in_features
+                        out_features = module.out_features
+
                     if quantization_config.quantization_method() == "llm_int8":
                         model._modules[name] = bnb.nn.Linear8bitLt(
-                            module.in_features,
-                            module.out_features,
+                            in_features,
+                            out_features,
                             module.bias is not None,
                             has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
                             threshold=quantization_config.llm_int8_threshold,
@@ -143,14 +156,16 @@ def _replace_with_bnb_linear(
                         pass
                     else:
                         model._modules[name] = bnb.nn.Linear4bit(
-                            module.in_features,
-                            module.out_features,
+                            in_features,
+                            out_features,
                             module.bias is not None,
                             quantization_config.bnb_4bit_compute_dtype,
                             compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                             quant_type=quantization_config.bnb_4bit_quant_type,
                         )
                 has_been_replaced = True
+                # Store the module class in case we need to transpose the weight later
+                model._modules[name].source_cls = type(module)
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
         if len(list(module.children())) > 0:
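A quick way to see the effect of the replacement pass on GPT-2 (illustrative sketch, assuming a CUDA device with bitsandbytes installed): Conv1D projections such as c_attn, c_proj and c_fc should come back as bnb.nn.Linear8bitLt, each carrying the new source_cls attribute.

# Sketch: inspect which modules were swapped during an 8-bit load of GPT-2.
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_8bit=True, device_map="auto")
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear8bitLt):
        # source_cls records the original class (Conv1D for GPT-2) so that
        # set_module_quantized_tensor_to_device knows whether to transpose.
        print(name, module.source_cls.__name__)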
@@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
-            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
             " Please double check your model architecture, or submit an issue on github if you think this is"
             " a bug."
         )
@@ -39,6 +39,12 @@ from transformers.testing_utils import (
 from transformers.utils.versions import importlib_metadata
 
 
+def get_some_linear_layer(model):
+    if model.config.model_type == "gpt2":
+        return model.transformer.h[0].mlp.c_fc
+    return model.transformer.h[0].mlp.dense_4h_to_h
+
+
 if is_torch_available():
     import torch
     import torch.nn as nn
@@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase):
     EXPECTED_OUTPUTS = set()
     EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
     EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
+    EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
     MAX_NEW_TOKENS = 10
 
     def setUp(self):
@@ -135,7 +142,8 @@ class Bnb4BitTest(Base4bitTest):
         mem_4bit = self.model_4bit.get_memory_footprint()
 
         self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
-        self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
+        linear = get_some_linear_layer(self.model_4bit)
+        self.assertTrue(linear.weight.__class__ == Params4bit)
 
     def test_linear_are_4bit(self):
         r"""
@@ -473,3 +481,8 @@ class Bnb4BitTestTraining(Base4bitTest):
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
             elif isinstance(module, nn.Embedding):
                 self.assertTrue(module.weight.grad is None)
+
+
+class Bnb4BitGPT2Test(Bnb4BitTest):
+    model_name = "gpt2-xl"
+    EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
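EXPECTED_RELATIVE_DIFFERENCE is the fp16-to-4bit memory ratio asserted by the inherited memory-footprint test. A hedged sketch of how such a value can be reproduced locally (the exact ratio depends on the checkpoint and on which modules stay unquantized):

# Sketch: compute the relative memory footprint that the test compares against.
import torch
from transformers import AutoModelForCausalLM

model_id = "gpt2-xl"
model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model_4bit = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")

ratio = model_fp16.get_memory_footprint() / model_4bit.get_memory_footprint()
print(f"fp16 / 4bit memory ratio: {ratio:.4f}")  # roughly 3.32 for gpt2-xl in 4-bit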
@@ -41,6 +41,12 @@ from transformers.testing_utils import (
 from transformers.utils.versions import importlib_metadata
 
 
+def get_some_linear_layer(model):
+    if model.config.model_type == "gpt2":
+        return model.transformer.h[0].mlp.c_fc
+    return model.transformer.h[0].mlp.dense_4h_to_h
+
+
 if is_accelerate_available():
     from accelerate import PartialState
     from accelerate.logging import get_logger
@@ -142,7 +148,7 @@ class MixedInt8Test(BaseMixedInt8Test):
         mem_8bit = self.model_8bit.get_memory_footprint()
 
         self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
-        self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
+        self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params)
 
     def test_linear_are_8bit(self):
         r"""
@@ -292,8 +298,9 @@ class MixedInt8Test(BaseMixedInt8Test):
             model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
 
-            self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
-            self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertTrue(linear.weight.__class__ == Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
 
             # generate
             encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -318,8 +325,9 @@ class MixedInt8Test(BaseMixedInt8Test):
             model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
 
-            self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
-            self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertTrue(linear.weight.__class__ == Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
 
             # generate
             encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -339,8 +347,9 @@ class MixedInt8Test(BaseMixedInt8Test):
         model = AutoModelForCausalLM.from_pretrained(model_id)
 
-        self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
-        self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
+        linear = get_some_linear_layer(model)
+        self.assertTrue(linear.weight.__class__ == Int8Params)
+        self.assertTrue(hasattr(linear.weight, "SCB"))
 
         # generate
         encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -748,3 +757,13 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
             elif isinstance(module, nn.Embedding):
                 self.assertTrue(module.weight.grad is None)
+
+
+class MixedInt8GPT2Test(MixedInt8Test):
+    model_name = "gpt2-xl"
+    EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
+    EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
+
+    def test_int8_from_pretrained(self):
+        # TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
+        pass