Unverified Commit c8b26096 authored by Younes Belkada, committed by GitHub

[`core`] fix 4bit `num_parameters` (#26132)

* fix 4bit `num_parameters`

* stronger check
parent 7db1ad63
@@ -989,12 +989,33 @@ class ModuleUtilsMixin:
             embedding_param_names = [
                 f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision; something went wrong."
+                    " Make sure to install bitsandbytes with `pip install bitsandbytes`."
+                )
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, multiply the reported count by 2: the 4bit weights are packed two per
+                # uint8 storage element, so `numel()` on a `Params4bit` tensor is half the logical parameter count.
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    total_numel.append(param.numel() * 2)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
 
     def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
         """
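The core of the fix is the `* 2` applied to `Params4bit` tensors. Below is a small illustrative sketch in plain PyTorch (not the actual bitsandbytes storage layout, which is only assumed here for illustration) of why the factor is needed: 4-bit codes are packed two per `uint8` storage element, so `numel()` on the packed tensor reports half the logical weight count.

```python
import torch

# Hypothetical packing sketch: store two 4-bit codes in each uint8 element,
# the way a Params4bit-style buffer ends up with half as many stored elements.
logical_codes = torch.randint(0, 16, (10,), dtype=torch.uint8)   # 10 logical 4-bit weights
packed = (logical_codes[0::2] << 4) | logical_codes[1::2]        # 5 uint8 storage elements

assert packed.numel() == logical_codes.numel() // 2
# num_parameters() therefore doubles numel() for 4-bit params to recover the logical count:
assert packed.numel() * 2 == logical_codes.numel()
print(packed.numel(), "stored uint8 elements ->", packed.numel() * 2, "logical parameters")
```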
@@ -118,6 +118,17 @@ class Bnb4BitTest(Base4bitTest):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_quantization_num_parameters(self):
+        r"""
+        Test if the number of returned parameters is correct
+
+        See: https://github.com/huggingface/transformers/issues/25978
+        """
+        num_params_4bit = self.model_4bit.num_parameters()
+        num_params_fp16 = self.model_fp16.num_parameters()
+
+        self.assertEqual(num_params_4bit, num_params_fp16)
+
     def test_quantization_config_json_serialization(self):
         r"""
         A simple test to check if the quantization config is correctly serialized and deserialized
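For context, here is a rough user-side sketch of what the new test asserts (assumes a CUDA GPU with `bitsandbytes` and `accelerate` installed; the checkpoint name is an illustrative choice, not necessarily the one used in the test suite):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "facebook/opt-350m"  # any causal LM checkpoint works here

model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# Before this fix, the 4-bit count came out at roughly half the fp16 count, because
# Params4bit.numel() counts packed uint8 storage elements rather than logical weights.
print(model_fp16.num_parameters(), model_4bit.num_parameters())
```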