"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "efeab6a3f1eeaffc2cec350ffce797f209ba38f8"
Unverified Commit 7d65efec authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[bnb] Let's warn users when saving 8-bit models (#20282)



* add warning on 8-bit models

- added tests
- added wrapper

* move to a private attribute

- remove wrapper
- changed `save_pretrained` method

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix suggestions
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0a144b8c
...@@ -1538,6 +1538,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1538,6 +1538,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
kwargs: kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
warnings.warn(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. ",
UserWarning,
)
if "save_config" in kwargs: if "save_config" in kwargs:
warnings.warn( warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
...@@ -2340,6 +2348,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2340,6 +2348,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
load_in_8bit=load_in_8bit, load_in_8bit=load_in_8bit,
) )
cls.is_loaded_in_8bit = load_in_8bit
# make sure token embedding weights are still tied if needed # make sure token embedding weights are still tied if needed
model.tie_weights() model.tie_weights()
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc import gc
import tempfile
import unittest import unittest
from transformers import ( from transformers import (
...@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test): ...@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_warns_save_pretrained(self):
    r"""
    Check that calling `save_pretrained` on a model that was converted to 8-bit
    emits a `UserWarning` (saving 8-bit weights may lead to unexpected behavior).
    """
    # Write into a throwaway directory so the test leaves no files behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self.assertWarns(UserWarning):
            self.model_8bit.save_pretrained(tmp_dir)
class MixedInt8ModelClassesTest(BaseMixedInt8Test): class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self): def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment