Unverified Commit 4f7806ef authored by Poedator, committed by GitHub

[bnb] Let's make serialization of 4bit models possible (#26037)



* updated bitsandbytes.py

* rm test_raise_* from test_4bit.py

* add test_4bit_serialization.py

* modeling_utils bulk edits

* bnb_ver 0.41.3 in integrations/bitsandbytes.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* @slow reinstated
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* bnb ver 0.41.3 in src/transformers/modeling_utils.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* rm bnb version todo in integrations/bitsandbytes.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* moved 4b serialization tests to test_4bit

* tests upd for opt

* to torch_device
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* ruff fixes to tests

* rm redundant bnb version check in mod_utils
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* restore _hf_peft_config_loaded in modeling_utils.py::2188
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* restore _hf_peft_config_loaded test in modeling_utils.py::2199
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fixed NOT getattr(self, "is_8bit_serializable")
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* setting model.is_4bit_serializable

* rm separate fp16_statistics arg from set_module...

* rm else branch in integrations::bnb::set_module

* bnb 4bit dtype check

* upd comment on 4bit weights

* upd tests for FP4 safe

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent e268d7e5
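For orientation before the diff: the end-to-end round trip this change enables looks roughly like the sketch below. This is a minimal example, assuming `bitsandbytes >= 0.41.3`, a CUDA device, and the `facebook/opt-125m` checkpoint used by the new tests; the output directory name is illustrative.

```python
# Minimal sketch of the round trip this PR enables. Assumes bitsandbytes >= 0.41.3,
# a CUDA device, and the "facebook/opt-125m" checkpoint used by the new tests;
# the output directory name is illustrative.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=quantization_config, device_map="auto"
)

# Saving a 4-bit model no longer raises NotImplementedError: the packed uint8 weights
# and their quant_state components are written to the checkpoint (safetensors or not).
model.save_pretrained("opt-125m-bnb-4bit", safe_serialization=True)

# Reloading restores bnb.nn.Params4bit weights via Params4bit.from_prequantized().
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-bnb-4bit", device_map="auto")
```

The new `BaseSerializationTest` added at the end of this diff exercises exactly this flow and additionally compares `quant_state` contents, logits, and `generate()` outputs between the original and reloaded models.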
File: integrations/bitsandbytes.py

@@ -21,7 +21,7 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)


-def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
+def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
     """
     A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
     `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
@@ -37,8 +37,8 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
             The device on which to set the tensor.
         value (`torch.Tensor`, *optional*):
             The value of the tensor (useful when going from the meta device to any other device).
-        fp16_statistics (`torch.HalfTensor`, *optional*):
-            The list of fp16 statistics to set on the module, used for serialization.
+        quantized_stats (`dict[str, Any]`, *optional*):
+            Dict with items for either 4-bit or 8-bit serialization
     """
     # Recurse if needed
     if "." in tensor_name:
@@ -58,8 +58,7 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
     if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
         raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

-    is_4bit = False
-    is_8bit = False
+    prequantized_loading = quantized_stats is not None
     if is_buffer or not is_bitsandbytes_available():
         is_8bit = False
         is_4bit = False
@@ -74,32 +73,53 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
             new_value = old_value.to(device)
         elif isinstance(value, torch.Tensor):
             new_value = value.to("cpu")
-            if value.dtype == torch.int8:
-                is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
-                    "0.37.2"
-                )
-                if not is_8bit_serializable:
-                    raise ValueError(
-                        "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
-                        "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
-                    )
         else:
             new_value = torch.tensor(value, device="cpu")

         # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
         # Since weights are saved in the correct "orientation", we skip transposing when loading.
-        if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
+        if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
             new_value = new_value.T

         kwargs = old_value.__dict__
+
+        if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
+            raise ValueError(
+                f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
+            )
+
         if is_8bit:
+            is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
+                "0.37.2"
+            )
+            if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
+                raise ValueError(
+                    "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
+                    "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+                )
             new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
+            if prequantized_loading:
+                setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
         elif is_4bit:
-            new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
+            if prequantized_loading:
+                is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
+                    "0.41.3"
+                )
+                if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
+                    raise ValueError(
+                        "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
+                        "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+                    )
+                new_value = bnb.nn.Params4bit.from_prequantized(
+                    data=new_value,
+                    quantized_stats=quantized_stats,
+                    requires_grad=False,
+                    device=device,
+                    **kwargs,
+                )
+            else:
+                new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)

         module._parameters[tensor_name] = new_value
-        if fp16_statistics is not None:
-            setattr(module.weight, "SCB", fp16_statistics.to(device))
     else:
         if value is None:
@@ -117,7 +137,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
 def _replace_with_bnb_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
+    model,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    has_been_replaced=False,
 ):
     """
     Private method that wraps the recursion for module replacement.
...
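The helper above now accepts a `quantized_stats` dict and rebuilds 4-bit weights with `bnb.nn.Params4bit.from_prequantized`. As a rough illustration of what it consumes, the sketch below loads a model in 4-bit and inspects the packed weight plus the extra per-parameter entries in its state dict (assumes `bitsandbytes >= 0.41.3` and a CUDA device; the exact auxiliary key names come from bitsandbytes' `QuantState` serialization).

```python
# Sketch: inspect the serialized form of one 4-bit weight. Assumes bitsandbytes
# >= 0.41.3 and a CUDA device; key names below depend on bitsandbytes internals.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
    device_map="auto",
)
sd = model.state_dict()

# The packed weight is a uint8 tensor holding two 4-bit values per byte.
w = sd["model.decoder.layers.0.fc1.weight"]
print(w.dtype, w.shape)

# Alongside it sit per-parameter quant_state components, e.g. a key ending in
# ".weight.quant_state.bitsandbytes__nf4"; on load these are gathered back into
# `quantized_stats` and handed to Params4bit.from_prequantized().
print([k for k in sd if k.startswith("model.decoder.layers.0.fc1.weight.")])
```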
File: src/transformers/modeling_utils.py

@@ -675,6 +675,7 @@ def _load_state_dict_into_meta_model(
     is_quantized=False,
     is_safetensors=False,
     keep_in_fp32_modules=None,
+    unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -776,16 +777,40 @@ def _load_state_dict_into_meta_model(
         elif not is_quantized:
             # For backward compatibility with older versions of `accelerate`
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
-        else:
-            if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
-                fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
-            else:
-                fp16_statistics = None
-
-            if "SCB" not in param_name:
-                set_module_quantized_tensor_to_device(
-                    model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
-                )
+        elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
+            # handling newly quantized weights and loaded quantized weights
+            # edit the param.dtype restrictions and is_quantized condition when adding new quant methods
+            quantized_stats = {}
+            if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
+                param_name + ".quant_state.bitsandbytes__nf4" in state_dict
+            ):
+                # 4bit loading. Collecting components for restoring quantized weight
+                # This can be expanded to make a universal call for any quantized weight loading
+                for k, v in state_dict.items():
+                    if param_name + "." in k:
+                        quantized_stats[k] = v
+                        unexpected_keys.remove(k)
+
+                set_module_quantized_tensor_to_device(
+                    model, param_name, param_device, value=param, quantized_stats=quantized_stats
+                )
+
+            elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
+                # 8bit loading. Could be combined with the above 4bit call.
+                # condition looks unreliable
+                fp16_statistics_key = param_name.replace("weight", "SCB")
+                unexpected_keys.remove(fp16_statistics_key)
+
+                set_module_quantized_tensor_to_device(
+                    model,
+                    param_name,
+                    param_device,
+                    value=param,
+                    quantized_stats={"SCB": state_dict[fp16_statistics_key]},
+                )
+        else:
+            # loading not quantized params in quantized model
+            set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)

     return error_msgs, offload_index, state_dict_index
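To make the new 4-bit branch above concrete, here is a self-contained sketch of the key-gathering step using a toy state dict (plain Python, no bitsandbytes required; the auxiliary key names other than `.quant_state.bitsandbytes__nf4` are hypothetical placeholders).

```python
# Toy version of the quant_state collection performed above; key names other than
# ".quant_state.bitsandbytes__nf4" are hypothetical placeholders for illustration.
param_name = "model.decoder.layers.0.fc1.weight"
state_dict = {
    param_name: "packed uint8 tensor",
    param_name + ".absmax": "per-block scales",
    param_name + ".quant_map": "nf4 codebook",
    param_name + ".quant_state.bitsandbytes__nf4": "packed quant_state metadata",
    "model.decoder.layers.0.fc1.bias": "fp16 tensor",
}
# Keys the model itself does not declare as parameters start out as "unexpected".
unexpected_keys = [k for k in state_dict if k != param_name and not k.endswith("bias")]

quantized_stats = {}
if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
    param_name + ".quant_state.bitsandbytes__nf4" in state_dict
):
    for k, v in state_dict.items():
        if param_name + "." in k:
            quantized_stats[k] = v      # later passed to Params4bit.from_prequantized()
            unexpected_keys.remove(k)   # so they are not reported as unexpected keys

print(sorted(quantized_stats))  # the three auxiliary entries
print(unexpected_keys)          # []
```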
@@ -2197,15 +2222,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             and not getattr(self, "is_8bit_serializable", False)
             and not _hf_peft_config_loaded
         ):
-            raise ValueError(
-                "You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
-                " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
+            raise NotImplementedError(
+                "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
+                "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
             )

-        # If the model has adapters attached, you can save the adapters
-        if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
+        if (
+            getattr(self, "is_loaded_in_4bit", False)
+            and not getattr(self, "is_4bit_serializable", False)
+            and not _hf_peft_config_loaded
+        ):
             raise NotImplementedError(
-                "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
+                "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
+                "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
             )

         if getattr(self, "_awq_is_fused", False):
@@ -2774,8 +2803,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             use_safetensors = False

         if is_bitsandbytes_available():
+            is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
             is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
         else:
+            is_4bit_serializable = False
             is_8bit_serializable = False

         if trust_remote_code is True:
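The two flags above are plain version gates. In isolation the check looks like this (a small sketch, assuming `bitsandbytes` is installed):

```python
# Standalone version of the serializability gate used above.
import importlib.metadata

from packaging import version

bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
is_4bit_serializable = bnb_version >= version.parse("0.41.3")
is_8bit_serializable = bnb_version > version.parse("0.37.2")
print(bnb_version, is_4bit_serializable, is_8bit_serializable)
```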
@@ -3064,10 +3095,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if low_cpu_mem_usage is None:
                 low_cpu_mem_usage = True

-        if (
-            is_8bit_serializable
-            and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
-            and load_in_8bit
+        if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
+            (is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
         ):
             if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
                 logger.warning(
@@ -3077,8 +3106,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             config.quantization_config = quantization_config
         elif (
-            is_8bit_serializable
-            and not load_in_8bit
+            (is_8bit_serializable or is_4bit_serializable)
+            and not (load_in_8bit or load_in_4bit)
             and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
         ):
             quantization_config = config.quantization_config
@@ -3093,8 +3122,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )

             load_in_8bit = quantization_config.load_in_8bit
+            load_in_4bit = quantization_config.load_in_4bit

-            if load_in_8bit:
+            if load_in_8bit or load_in_4bit:
                 if torch_dtype is None:
                     torch_dtype = torch.float16
                 if device_map is None:
@@ -3112,12 +3142,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         elif (
             not is_8bit_serializable
-            and not load_in_8bit
+            and not (load_in_8bit or load_in_4bit)
             and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
         ):
             logger.warning(
                 "Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
-                " `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
+                " `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
                 " `pip install --upgrade bitsandbytes`."
             )
@@ -3525,6 +3555,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             config.quantization_config = quantization_config
             model.is_8bit_serializable = is_8bit_serializable
+            model.is_4bit_serializable = is_4bit_serializable

         if load_in_8bit and torch_dtype is None:
             logger.warning(
@@ -4018,10 +4049,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         model_key in model_state_dict
                         and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                     ):
-                        mismatched_keys.append(
-                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                        )
-                        del state_dict[checkpoint_key]
+                        if (
+                            state_dict[checkpoint_key].shape[-1] == 1
+                            and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
+                        ):
+                            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
+                            # Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
+                            pass
+                        else:
+                            mismatched_keys.append(
+                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                            )
+                            del state_dict[checkpoint_key]
             return mismatched_keys

         if resolved_archive_file is not None:
@@ -4130,6 +4169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         is_quantized=is_quantized,
                         is_safetensors=is_safetensors,
                         keep_in_fp32_modules=keep_in_fp32_modules,
+                        unexpected_keys=unexpected_keys,
                     )
                     error_msgs += new_error_msgs
                 else:
@@ -4167,10 +4207,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

-        if is_quantized:
-            unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
-            missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
-
         if len(unexpected_keys) > 0:
             archs = [] if model.config.architectures is None else model.config.architectures
             warner = logger.warning if model.__class__.__name__ in archs else logger.info
...
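The `_find_mismatched_keys` change above relies on a simple packing invariant: two 4-bit values share one uint8 byte, so a packed checkpoint tensor has shape `(numel // 2, 1)` relative to the unpacked module weight. A minimal sketch of the heuristic (CPU-only, no bitsandbytes needed):

```python
# Sketch of the 4-bit size-mismatch heuristic: packed 4-bit storage is an (N/2, 1)
# uint8 tensor, so its shape never equals the unpacked weight's shape.
import torch

model_weight = torch.empty(3072, 768, dtype=torch.float16)  # unpacked module weight
checkpoint_weight = torch.empty(model_weight.numel() // 2, 1, dtype=torch.uint8)  # packed 4-bit storage

looks_like_valid_4bit = (
    checkpoint_weight.shape[-1] == 1
    and checkpoint_weight.numel() * 2 == model_weight.numel()
)
print(looks_like_valid_4bit)  # True -> not treated as a size mismatch
```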
File: test_4bit.py

@@ -20,6 +20,7 @@ import unittest
 from packaging import version

 from transformers import (
+    AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
@@ -29,6 +30,7 @@ from transformers import (
     pipeline,
 )
 from transformers.testing_utils import (
+    is_bitsandbytes_available,
     is_torch_available,
     require_accelerate,
     require_bitsandbytes,
@@ -36,13 +38,21 @@ from transformers.testing_utils import (
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
+    torch_device,
 )


 def get_some_linear_layer(model):
     if model.config.model_type == "gpt2":
         return model.transformer.h[0].mlp.c_fc
-    return model.transformer.h[0].mlp.dense_4h_to_h
+    elif model.config.model_type == "opt":
+        try:
+            return model.decoder.layers[0].fc1
+        except AttributeError:
+            # for AutoModelforCausalLM
+            return model.model.decoder.layers[0].fc1
+    else:
+        return model.transformer.h[0].mlp.dense_4h_to_h


 if is_torch_available():
@@ -68,6 +78,10 @@ if is_torch_available():
         return self.module(input, *args, **kwargs) + self.adapter(input)


+if is_bitsandbytes_available():
+    import bitsandbytes as bnb
+
+
 @require_bitsandbytes
 @require_accelerate
 @require_torch
@@ -225,28 +239,6 @@ class Bnb4BitTest(Base4bitTest):
         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

-    def test_raise_on_save_pretrained(self):
-        r"""
-        Test whether trying to save a model after converting it in 8-bit will throw a warning.
-        """
-        with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
-            self.model_4bit.save_pretrained(tmpdirname)
-
-    def test_raise_if_config_and_load_in_4bit(self):
-        r"""
-        Test that loading the model with the config and `load_in_4bit` raises an error
-        """
-        bnb_config = BitsAndBytesConfig()
-        with self.assertRaises(ValueError):
-            _ = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                quantization_config=bnb_config,
-                load_in_4bit=True,
-                device_map="auto",
-                bnb_4bit_quant_type="nf4",
-            )
-
     def test_device_and_dtype_assignment(self):
         r"""
         Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
@@ -346,8 +338,6 @@ class Bnb4BitT5Test(unittest.TestCase):
         `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
         both cases.
         """
-        import bitsandbytes as bnb
-
         from transformers import T5ForConditionalGeneration

         # test with `t5-small`
@@ -521,3 +511,140 @@ class Bnb4BitTestTraining(Base4bitTest):
 class Bnb4BitGPT2Test(Bnb4BitTest):
     model_name = "gpt2-xl"
     EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class BaseSerializationTest(unittest.TestCase):
+    model_name = "facebook/opt-125m"
+    input_text = "Mars colonists' favorite meals are"
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
+        r"""
+        Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default.
+        See ExtendedSerializationTest class for more params combinations.
+        """
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        self.quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type=quant_type,
+            bnb_4bit_use_double_quant=double_quant,
+            bnb_4bit_compute_dtype=torch.bfloat16,
+        )
+        model_0 = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            quantization_config=self.quantization_config,
+            device_map=torch_device,
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
+
+            config = AutoConfig.from_pretrained(tmpdirname)
+            self.assertTrue(hasattr(config, "quantization_config"))
+
+            model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
+
+        # checking quantized linear module weight
+        linear = get_some_linear_layer(model_1)
+        self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+        self.assertTrue(hasattr(linear.weight, "quant_state"))
+        self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
+
+        # checking memory footprint
+        self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
+
+        # Matching all parameters and their quant_state items:
+        d0 = dict(model_0.named_parameters())
+        d1 = dict(model_1.named_parameters())
+        self.assertTrue(d0.keys() == d1.keys())
+
+        for k in d0.keys():
+            self.assertTrue(d0[k].shape == d1[k].shape)
+            self.assertTrue(d0[k].device.type == d1[k].device.type)
+            self.assertTrue(d0[k].device == d1[k].device)
+            self.assertTrue(d0[k].dtype == d1[k].dtype)
+            self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
+
+            if isinstance(d0[k], bnb.nn.modules.Params4bit):
+                for v0, v1 in zip(
+                    d0[k].quant_state.as_dict().values(),
+                    d1[k].quant_state.as_dict().values(),
+                ):
+                    if isinstance(v0, torch.Tensor):
+                        self.assertTrue(torch.equal(v0, v1.to(v0.device)))
+                    else:
+                        self.assertTrue(v0 == v1)
+
+        # comparing forward() outputs
+        encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        out_0 = model_0(**encoded_input)
+        out_1 = model_1(**encoded_input)
+        self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))

+        # comparing generate() outputs
+        encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10)
+        output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10)
+
+        def _decode(token):
+            return tokenizer.decode(token, skip_special_tokens=True)
+
+        self.assertEqual(
+            [_decode(x) for x in output_sequences_0],
+            [_decode(x) for x in output_sequences_1],
+        )
+
+
+class ExtendedSerializationTest(BaseSerializationTest):
+    """
+    tests more combinations of parameters
+    """
+
+    def test_nf4_single_unsafe(self):
+        self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
+
+    def test_nf4_single_safe(self):
+        self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
+
+    def test_nf4_double_unsafe(self):
+        self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
+
+    # nf4 double safetensors quantization is tested in test_serialization() method from the parent class
+
+    def test_fp4_single_unsafe(self):
+        self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
+
+    def test_fp4_single_safe(self):
+        self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
+
+    def test_fp4_double_unsafe(self):
+        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
+
+    def test_fp4_double_safe(self):
+        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
+
+
+class BloomSerializationTest(BaseSerializationTest):
+    """
+    default BaseSerializationTest config tested with Bloom family model
+    """
+
+    model_name = "bigscience/bloom-560m"
+
+
+class GPTSerializationTest(BaseSerializationTest):
+    """
+    default BaseSerializationTest config tested with GPT family model
+    """
+
+    model_name = "gpt2-xl"