Unverified Commit a6d05d55 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`bnb`] Fix bnb config json serialization (#24137)



* fix bnb config json serialization

* forward contrib credits from discussions

---------
Co-authored-by: default avatarAndrechang <Andrechang@users.noreply.github.com>
parent e2972dff
...@@ -784,6 +784,13 @@ class PretrainedConfig(PushToHubMixin): ...@@ -784,6 +784,13 @@ class PretrainedConfig(PushToHubMixin):
): ):
serializable_config_dict[key] = value serializable_config_dict[key] = value
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_torch_dtype_to_str(serializable_config_dict) self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict return serializable_config_dict
......
...@@ -111,6 +111,19 @@ class Bnb4BitTest(Base4bitTest): ...@@ -111,6 +111,19 @@ class Bnb4BitTest(Base4bitTest):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized
"""
config = self.model_4bit.config
self.assertTrue(hasattr(config, "quantization_config"))
_ = config.to_dict()
_ = config.to_diff_dict()
_ = config.to_json_string()
def test_memory_footprint(self): def test_memory_footprint(self):
r""" r"""
A simple test to check if the model conversion has been done correctly by checking on the A simple test to check if the model conversion has been done correctly by checking on the
......
...@@ -118,6 +118,19 @@ class MixedInt8Test(BaseMixedInt8Test): ...@@ -118,6 +118,19 @@ class MixedInt8Test(BaseMixedInt8Test):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized
"""
config = self.model_8bit.config
self.assertTrue(hasattr(config, "quantization_config"))
_ = config.to_dict()
_ = config.to_diff_dict()
_ = config.to_json_string()
def test_memory_footprint(self): def test_memory_footprint(self):
r""" r"""
A simple test to check if the model conversion has been done correctly by checking on the A simple test to check if the model conversion has been done correctly by checking on the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment