Unverified commit 4f7806ef authored by Poedator, committed by GitHub

[bnb] Let's make serialization of 4bit models possible (#26037)



* updated bitsandbytes.py

* rm test_raise_* from test_4bit.py

* add test_4bit_serialization.py

* modeling_utils bulk edits

* bnb_ver 0.41.3 in integrations/bitsandbytes.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* @slow reinstated
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* bnb ver 0.41.3 in  src/transformers/modeling_utils.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* rm bnb version todo in  integrations/bitsandbytes.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* moved 4b serialization tests to test_4bit

* tests upd for opt

* to torch_device
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* ruff fixes to tests

* rm redundant bnb version check in mod_utils
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* restore _hf_peft_config_loaded  modeling_utils.py::2188
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* restore _hf_peft_config_loaded  test in modeling_utils.py::2199
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fixed NOT getattr(self, "is_8bit_serializable")
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* setting model.is_4bit_serializable

* rm separate fp16_statistics arg from set_module...

* rm else branch in integrations::bnb::set_module

* bnb 4bit dtype check

* upd comment on 4bit weights

* upd tests for FP4 safe

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent e268d7e5
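In practice, this change enables the following round trip. The snippet is a minimal sketch mirroring the new serialization tests added in this PR; it assumes `bitsandbytes>=0.41.3`, a CUDA device, and uses `facebook/opt-125m` as in the tests.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize on load (nf4 + double quantization, as in BaseSerializationTest below).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=quantization_config, device_map="auto"
)

# New with this PR: the packed 4-bit weights and their quant_state components are written to disk ...
model.save_pretrained("opt-125m-4bit", safe_serialization=True)

# ... and reloaded without re-quantizing; the quantization_config is picked up from the saved config.
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-4bit", device_map="auto")
```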
@@ -21,7 +21,7 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
@@ -37,8 +37,8 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
fp16_statistics (`torch.HalfTensor`, *optional*):
The list of fp16 statistics to set on the module, used for serialization.
quantized_stats (`dict[str, Any]`, *optional*):
Dict with items for either 4-bit or 8-bit serialization
"""
# Recurse if needed
if "." in tensor_name:
@@ -58,8 +58,7 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {device}.")
is_4bit = False
is_8bit = False
prequantized_loading = quantized_stats is not None
if is_buffer or not is_bitsandbytes_available():
is_8bit = False
is_4bit = False
@@ -74,32 +73,53 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
if value.dtype == torch.int8:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
else:
new_value = torch.tensor(value, device="cpu")
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
new_value = new_value.T
kwargs = old_value.__dict__
if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
raise ValueError(
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
)
if is_8bit:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
elif is_4bit:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
"0.41.3"
)
if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
raise ValueError(
"Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Params4bit.from_prequantized(
data=new_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=device,
**kwargs,
)
else:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))
else:
if value is None:
@@ -117,7 +137,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
def _replace_with_bnb_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.
......
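For orientation before the `_load_state_dict_into_meta_model` changes below, here is a toy restatement of the consistency guard added above and of the two `quantized_stats` shapes the reworked helper accepts. The tensor values and key names are illustrative placeholders; only the `SCB` key and the `quant_state.bitsandbytes__nf4` suffix come from the diff itself.

```python
import torch

def check_prequantized_consistency(new_value: torch.Tensor, quantized_stats) -> None:
    # Toy restatement of the guard above: pre-quantized checkpoints must carry
    # int8/uint8 data, and freshly quantized loads must not.
    prequantized_loading = quantized_stats is not None
    if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
        raise ValueError(f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status.")

# 8-bit prequantized loading passes only the SCB statistics, which the helper attaches to the Int8Params ...
check_prequantized_consistency(torch.zeros(4, dtype=torch.int8), {"SCB": torch.ones(4)})

# ... while 4-bit prequantized loading passes every "<param_name>.*" checkpoint entry,
# which the helper forwards unchanged to bnb.nn.Params4bit.from_prequantized.
check_prequantized_consistency(
    torch.zeros(8, 1, dtype=torch.uint8),  # packed 4-bit data is stored as uint8
    {"weight.quant_state.bitsandbytes__nf4": torch.zeros(1, dtype=torch.uint8)},  # contents sketched further down
)
```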
@@ -675,6 +675,7 @@ def _load_state_dict_into_meta_model(
is_quantized=False,
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None,  # passing `unexpected_keys` so quantization-related entries can be removed from it
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -776,16 +777,40 @@ def _load_state_dict_into_meta_model(
elif not is_quantized:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
else:
fp16_statistics = None
elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
# handling newly quantized weights and loaded quantized weights
# edit the param.dtype restrictions and is_quantized condition when adding new quant methods
quantized_stats = {}
if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
param_name + ".quant_state.bitsandbytes__nf4" in state_dict
):
# 4-bit loading. Collect the components needed to restore the quantized weight.
# This could be extended into a universal path for loading any quantized weight.
for k, v in state_dict.items():
if param_name + "." in k:
quantized_stats[k] = v
unexpected_keys.remove(k)
set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, quantized_stats=quantized_stats
)
if "SCB" not in param_name:
elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
# 8-bit loading. Could be combined with the 4-bit branch above.
# The SCB-based detection condition looks fragile.
fp16_statistics_key = param_name.replace("weight", "SCB")
unexpected_keys.remove(fp16_statistics_key)
set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
model,
param_name,
param_device,
value=param,
quantized_stats={"SCB": state_dict[fp16_statistics_key]},
)
else:
# loading non-quantized params in a quantized model
set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)
return error_msgs, offload_index, state_dict_index
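A self-contained toy illustration of the 4-bit collection step above: everything prefixed with `<param_name>.` is gathered into `quantized_stats` and dropped from `unexpected_keys`. The key names are hypothetical, patterned on the `quant_state.bitsandbytes__nf4` suffix the loader checks.

```python
import torch

param_name = "model.decoder.layers.0.fc1.weight"  # hypothetical parameter name
state_dict = {
    param_name: torch.zeros(8, 1, dtype=torch.uint8),  # packed 4-bit data
    param_name + ".quant_state.bitsandbytes__nf4": torch.zeros(1, dtype=torch.uint8),
    param_name + ".absmax": torch.ones(1),  # illustrative quant-state component
    "lm_head.weight": torch.zeros(2, 2),  # unrelated entry, left untouched
}
unexpected_keys = [k for k in state_dict if k != param_name]

# Mirror of the loop above: collect the quant-state entries and clean them out of unexpected_keys
# before handing everything to set_module_quantized_tensor_to_device.
quantized_stats = {}
for k, v in state_dict.items():
    if param_name + "." in k:
        quantized_stats[k] = v
        unexpected_keys.remove(k)

print(sorted(quantized_stats))  # the dict used to restore the 4-bit weight
print(unexpected_keys)          # ['lm_head.weight'] is all that remains
```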
@@ -2197,15 +2222,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
and not getattr(self, "is_8bit_serializable", False)
and not _hf_peft_config_loaded
):
raise ValueError(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
raise NotImplementedError(
"You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
"If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
)
# If the model has adapters attached, you can save the adapters
if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
if (
getattr(self, "is_loaded_in_4bit", False)
and not getattr(self, "is_4bit_serializable", False)
and not _hf_peft_config_loaded
):
raise NotImplementedError(
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
"You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
"If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
)
if getattr(self, "_awq_is_fused", False):
@@ -2774,8 +2803,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
use_safetensors = False
if is_bitsandbytes_available():
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
else:
is_4bit_serializable = False
is_8bit_serializable = False
if trust_remote_code is True:
@@ -3064,10 +3095,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
if (
is_8bit_serializable
and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
and load_in_8bit
if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
(is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
):
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
logger.warning(
@@ -3077,8 +3106,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
config.quantization_config = quantization_config
elif (
is_8bit_serializable
and not load_in_8bit
(is_8bit_serializable or is_4bit_serializable)
and not (load_in_8bit or load_in_4bit)
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
quantization_config = config.quantization_config
@@ -3093,8 +3122,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
load_in_8bit = quantization_config.load_in_8bit
load_in_4bit = quantization_config.load_in_4bit
if load_in_8bit:
if load_in_8bit or load_in_4bit:
if torch_dtype is None:
torch_dtype = torch.float16
if device_map is None:
@@ -3112,12 +3142,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif (
not is_8bit_serializable
and not load_in_8bit
and not (load_in_8bit or load_in_4bit)
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
logger.warning(
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
" `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
" `pip install --upgrade bitsandbytes`."
)
@@ -3525,6 +3555,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config.quantization_config = quantization_config
model.is_8bit_serializable = is_8bit_serializable
model.is_4bit_serializable = is_4bit_serializable
if load_in_8bit and torch_dtype is None:
logger.warning(
@@ -4018,10 +4049,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
if (
state_dict[checkpoint_key].shape[-1] == 1
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
):
# This skips size mismatches for 4-bit weights: two 4-bit values share one 8-bit container, so the packed shape differs.
# Without checking the module or parameter type, this is a practical way to detect valid 4-bit weights (see the numeric sketch just below).
pass
else:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
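A quick numeric check of the heuristic above, using a hypothetical fp16 weight shape; 4-bit packing stores two values per uint8 byte in a flat column, which is exactly what the two conditions detect.

```python
import torch

original = torch.empty(768, 3072, dtype=torch.float16)             # hypothetical full-precision weight
packed = torch.empty(original.numel() // 2, 1, dtype=torch.uint8)  # its 4-bit-packed storage

# The two conditions used above to skip the spurious "size mismatch" report:
assert packed.shape[-1] == 1
assert packed.numel() * 2 == original.numel()
print("looks like a valid 4-bit packed weight")
```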
if resolved_archive_file is not None:
@@ -4130,6 +4169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
error_msgs += new_error_msgs
else:
@@ -4167,10 +4207,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if is_quantized:
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
......
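For quick local debugging, the version gates introduced above can be reproduced on their own. This sketch assumes `bitsandbytes` is installed and simply mirrors the checks added to `from_pretrained`.

```python
import importlib.metadata
from packaging import version

bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
# 4-bit serialization is gated on bitsandbytes >= 0.41.3; 8-bit keeps the existing > 0.37.2 requirement.
is_4bit_serializable = bnb_version >= version.parse("0.41.3")
is_8bit_serializable = bnb_version > version.parse("0.37.2")
print(f"4-bit serializable: {is_4bit_serializable}, 8-bit serializable: {is_8bit_serializable}")
```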
@@ -20,6 +20,7 @@ import unittest
from packaging import version
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
@@ -29,6 +30,7 @@ from transformers import (
pipeline,
)
from transformers.testing_utils import (
is_bitsandbytes_available,
is_torch_available,
require_accelerate,
require_bitsandbytes,
@@ -36,13 +38,21 @@ from transformers.testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h
elif model.config.model_type == "opt":
try:
return model.decoder.layers[0].fc1
except AttributeError:
# for AutoModelForCausalLM
return model.model.decoder.layers[0].fc1
else:
return model.transformer.h[0].mlp.dense_4h_to_h
if is_torch_available():
@@ -68,6 +78,10 @@ if is_torch_available():
return self.module(input, *args, **kwargs) + self.adapter(input)
if is_bitsandbytes_available():
import bitsandbytes as bnb
@require_bitsandbytes
@require_accelerate
@require_torch
@@ -225,28 +239,6 @@ class Bnb4BitTest(Base4bitTest):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_raise_on_save_pretrained(self):
r"""
Test whether trying to save a model after converting it to 4-bit raises an error.
"""
with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
self.model_4bit.save_pretrained(tmpdirname)
def test_raise_if_config_and_load_in_4bit(self):
r"""
Test that loading the model with the config and `load_in_4bit` raises an error
"""
bnb_config = BitsAndBytesConfig()
with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=bnb_config,
load_in_4bit=True,
device_map="auto",
bnb_4bit_quant_type="nf4",
)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it to 4-bit will throw an error.
@@ -346,8 +338,6 @@ class Bnb4BitT5Test(unittest.TestCase):
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
both cases.
"""
import bitsandbytes as bnb
from transformers import T5ForConditionalGeneration
# test with `t5-small`
@@ -521,3 +511,140 @@ class Bnb4BitTestTraining(Base4bitTest):
class Bnb4BitGPT2Test(Bnb4BitTest):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class BaseSerializationTest(unittest.TestCase):
model_name = "facebook/opt-125m"
input_text = "Mars colonists' favorite meals are"
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
r"""
Test whether it is possible to serialize a model in 4-bit. Uses the most typical parameters as defaults.
See the ExtendedSerializationTest class for more parameter combinations.
"""
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type=quant_type,
bnb_4bit_use_double_quant=double_quant,
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_0 = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=self.quantization_config,
device_map=torch_device,
)
with tempfile.TemporaryDirectory() as tmpdirname:
model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
config = AutoConfig.from_pretrained(tmpdirname)
self.assertTrue(hasattr(config, "quantization_config"))
model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
# checking quantized linear module weight
linear = get_some_linear_layer(model_1)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
self.assertTrue(hasattr(linear.weight, "quant_state"))
self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
# checking memory footprint
self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
# Matching all parameters and their quant_state items:
d0 = dict(model_0.named_parameters())
d1 = dict(model_1.named_parameters())
self.assertTrue(d0.keys() == d1.keys())
for k in d0.keys():
self.assertTrue(d0[k].shape == d1[k].shape)
self.assertTrue(d0[k].device.type == d1[k].device.type)
self.assertTrue(d0[k].device == d1[k].device)
self.assertTrue(d0[k].dtype == d1[k].dtype)
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
if isinstance(d0[k], bnb.nn.modules.Params4bit):
for v0, v1 in zip(
d0[k].quant_state.as_dict().values(),
d1[k].quant_state.as_dict().values(),
):
if isinstance(v0, torch.Tensor):
self.assertTrue(torch.equal(v0, v1.to(v0.device)))
else:
self.assertTrue(v0 == v1)
# comparing forward() outputs
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
out_0 = model_0(**encoded_input)
out_1 = model_1(**encoded_input)
self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))
# comparing generate() outputs
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10)
output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10)
def _decode(token):
return tokenizer.decode(token, skip_special_tokens=True)
self.assertEqual(
[_decode(x) for x in output_sequences_0],
[_decode(x) for x in output_sequences_1],
)
class ExtendedSerializationTest(BaseSerializationTest):
"""
Tests more combinations of parameters.
"""
def test_nf4_single_unsafe(self):
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
def test_nf4_single_safe(self):
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
def test_nf4_double_unsafe(self):
self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
# nf4 double safetensors quantization is tested in test_serialization() method from the parent class
def test_fp4_single_unsafe(self):
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
def test_fp4_single_safe(self):
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
def test_fp4_double_unsafe(self):
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
def test_fp4_double_safe(self):
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
class BloomSerializationTest(BaseSerializationTest):
"""
Default BaseSerializationTest config tested with a Bloom-family model.
"""
model_name = "bigscience/bloom-560m"
class GPTSerializationTest(BaseSerializationTest):
"""
Default BaseSerializationTest config tested with a GPT-family model.
"""
model_name = "gpt2-xl"