Unverified Commit 9b25c164 authored by Younes Belkada, committed by GitHub

[`core` / `Quantization`] Fix for 8bit serialization tests (#27234)

* fix for 8bit serialization

* added regression tests.

* fixup
parent c52e429b
@@ -2110,7 +2110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             # We're going to remove aliases before saving
             ptrs = collections.defaultdict(list)
             for name, tensor in state_dict.items():
-                ptrs[id_tensor_storage(tensor)].append(name)
+                # Sometimes in the state_dict we have non-tensor objects.
+                # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                if isinstance(tensor, torch.Tensor):
+                    ptrs[id_tensor_storage(tensor)].append(name)
+                else:
+                    # In the non-tensor case, fall back to the pointer of the object itself
+                    ptrs[id(tensor)].append(name)
 
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
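For context, here is a minimal sketch of the alias-detection logic this hunk patches, run against a hand-built state dict. The `linear.weight_format` key is an illustrative stand-in for the `str` entries bitsandbytes leaves in an 8-bit model's state_dict (the exact key names come from bitsandbytes, not this PR), and `id_tensor_storage` is assumed to be the helper from `transformers.pytorch_utils`:

```python
# Sketch of the patched alias-grouping loop, assuming a state_dict that mixes
# tensors with non-tensor metadata, as bitsandbytes does for 8-bit weights.
import collections

import torch
from transformers.pytorch_utils import id_tensor_storage

weight = torch.randn(4, 4)
state_dict = {
    "linear.weight": weight,
    "linear.weight_alias": weight,        # shares storage with linear.weight
    "linear.weight_format": "row_major",  # illustrative non-tensor (str) entry
}

ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
    if isinstance(tensor, torch.Tensor):
        # id_tensor_storage keys on the underlying storage, so true aliases
        # (views or repeated references to one tensor) group together
        ptrs[id_tensor_storage(tensor)].append(name)
    else:
        # calling id_tensor_storage on a str would fail; before this fix the
        # loop did exactly that, breaking non-safetensors 8-bit serialization
        ptrs[id(tensor)].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # one group: ['linear.weight', 'linear.weight_alias']
```

Only genuinely shared tensors end up in `shared_ptrs`; the `id()` fallback never groups distinct objects, so non-tensor entries pass through untouched.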
@@ -369,6 +369,33 @@ class MixedInt8Test(BaseMixedInt8Test):
             self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
         )
 
+    def test_int8_serialization_regression(self):
+        r"""
+        Test whether it is possible to serialize a model in 8-bit without using safetensors.
+        """
+        from bitsandbytes.nn import Int8Params
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.model_8bit.save_pretrained(tmpdirname, safe_serialization=False)
+
+            # check that the saved config carries a `quantization_config` attribute
+            config = AutoConfig.from_pretrained(tmpdirname)
+            self.assertTrue(hasattr(config, "quantization_config"))
+
+            model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
+
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertTrue(linear.weight.__class__ == Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
+
+            # generate
+            encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+            output_sequences = model_from_saved.generate(
+                input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10
+            )
+
+            self.assertEqual(
+                self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
+            )
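Outside the test suite, the round-trip this regression test covers looks roughly like the sketch below. It assumes bitsandbytes and a CUDA GPU are available; the model id is an illustrative choice, not one the PR prescribes:

```python
# Hedged sketch of the non-safetensors 8-bit save/load round-trip.
import tempfile

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-560m"  # illustrative; any small causal LM works

# quantize on load; requires bitsandbytes and a CUDA device
model_8bit = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmpdirname:
    # safe_serialization=False takes the torch.save path patched above, where
    # bitsandbytes leaves non-tensor (str) entries in the state_dict
    model_8bit.save_pretrained(tmpdirname, safe_serialization=False)

    # the reloaded weights stay quantized (Int8Params) and keep their SCB stats
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
    print(tokenizer.decode(reloaded.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
```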
 
     def test_int8_serialization_sharded(self):
         r"""
         Test whether it is possible to serialize a model in 8-bit - sharded version.