Unverified Commit fd6a0ade authored by Younes Belkada, committed by GitHub

🚨🚨🚨 [`Quantization`] Store the original dtype in the config as a private attribute 🚨🚨🚨 (#26761)



* First step

* fix

* add adjustments for gptq

* change to `_pre_quantization_dtype`

* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix serialization

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 14b04b4b
@@ -854,6 +854,9 @@ class PretrainedConfig(PushToHubMixin):
                 else self.quantization_config
             )
 
+            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+            _ = serializable_config_dict.pop("_pre_quantization_dtype", None)
+
         self.dict_torch_dtype_to_str(serializable_config_dict)
 
         if "_flash_attn_2_enabled" in serializable_config_dict:
@@ -896,6 +899,9 @@ class PretrainedConfig(PushToHubMixin):
                 else self.quantization_config
             )
 
+            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+            _ = output.pop("_pre_quantization_dtype", None)
+
         self.dict_torch_dtype_to_str(output)
 
         return output

...
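Side note on the serialization handling above: `torch.dtype` values cannot go through `json.dumps`, and the existing `dict_torch_dtype_to_str` helper appears to rewrite only `torch_dtype` entries, so the private attribute is simply dropped before serializing. A minimal, self-contained sketch of the failure mode and the workaround (the dict below is illustrative, not a real config):

```python
import json

import torch

# Illustrative snapshot of a config dict right before JSON serialization.
config_dict = {
    "torch_dtype": torch.float16,              # converted to the string "float16" below
    "_pre_quantization_dtype": torch.float16,  # would make json.dumps raise a TypeError
}

# Mirror of the change in the diff: drop the private attribute entirely ...
_ = config_dict.pop("_pre_quantization_dtype", None)
# ... and stringify the public key, as dict_torch_dtype_to_str does for `torch_dtype`.
config_dict["torch_dtype"] = str(config_dict["torch_dtype"]).split(".")[1]

print(json.dumps(config_dict))  # {"torch_dtype": "float16"}
```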
@@ -2178,8 +2178,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
                 " model has already been set to the correct devices and casted to the correct `dtype`."
             )
-        else:
-            return super().to(*args, **kwargs)
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+            # For GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
+            # The correct API is to load the model with the desired dtype directly through `from_pretrained`.
+            dtype_present_in_args = False
+
+            if "dtype" not in kwargs:
+                for arg in args:
+                    if isinstance(arg, torch.dtype):
+                        dtype_present_in_args = True
+                        break
+            else:
+                dtype_present_in_args = True
+
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a GPTQ model to a new `dtype`. Make sure to load the model using `from_pretrained` with the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+        return super().to(*args, **kwargs)
 
     def half(self, *args):
         # Checks if the model is quantized
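A usage-level sketch of the new guard (the checkpoint name, and the assumption that `optimum` and `auto-gptq` are installed on a CUDA machine, are illustrative and not part of this diff): passing a dtype to `.to` on a GPTQ model now raises, and the supported path is to choose the dtype at load time.

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical pre-quantized GPTQ checkpoint, used only for illustration.
model_id = "TheBloke/Llama-2-7B-GPTQ"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

try:
    model.to(torch.float16)  # requesting a dtype through `.to` now raises for GPTQ models
except ValueError as err:
    print(err)

# The supported way to pick the dtype is at load time:
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)
```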
@@ -3165,6 +3182,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if hasattr(model, "quantization_method"):
             model.is_quantized = True
 
+            # We store the original dtype for quantized models as we cannot easily retrieve it
+            # once the weights have been quantized
+            # Note that once you have loaded a quantized model, you can't change its dtype so this will
+            # remain a single source of truth
+            config._pre_quantization_dtype = torch_dtype
+
         if isinstance(device_map, str):
             special_dtypes = {}
             if load_in_8bit or load_in_4bit:

...
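A short sketch of what this attribute buys downstream (the checkpoint name and the 4-bit setup are assumptions for illustration; a CUDA device and `bitsandbytes` are required): after loading a quantized checkpoint, the dtype the weights had before quantization can be read back from the config, while a non-quantized load does not carry the attribute at all.

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical small checkpoint, used only for illustration.
model_id = "facebook/opt-350m"

quantized = AutoModelForCausalLM.from_pretrained(
    model_id, load_in_4bit=True, torch_dtype=torch.float16
)
plain = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

print(quantized.config._pre_quantization_dtype)          # torch.float16
print(hasattr(plain.config, "_pre_quantization_dtype"))  # False
```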
@@ -156,6 +156,14 @@ class Bnb4BitTest(Base4bitTest):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == Params4bit)
 
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model successfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16)
+
     def test_linear_are_4bit(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the

...
@@ -186,6 +186,14 @@ class MixedInt8Test(BaseMixedInt8Test):
         _ = config.to_json_string()
 
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model successfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.model_8bit.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.model_8bit.config._pre_quantization_dtype == torch.float16)
+
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the

...
@@ -145,6 +145,26 @@ class GPTQTest(unittest.TestCase):
         self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)
 
+    def test_device_and_dtype_assignment(self):
+        r"""
+        Test whether trying to cast (or assign a device to) a model after quantizing it will throw an error.
+        Checks also if other models are cast correctly.
+        """
+        # This should work
+        _ = self.quantized_model.to(0)
+
+        with self.assertRaises(ValueError):
+            # Tries with a `dtype`
+            self.quantized_model.to(torch.float16)
+
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model successfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)
+
     def test_quantized_layers_class(self):
         """
         Simple test to check if the model conversion has been done correctly by checking on

...