Unverified commit 7d65efec, authored by Younes Belkada, committed by GitHub
Browse files

[bnb] Let's warn users when saving 8-bit models (#20282)



* add warning on 8-bit models

- added tests
- added wrapper

* move to a private attribute

- remove wrapper
- changed `save_pretrained` method

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix suggestions
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0a144b8c
......@@ -1538,6 +1538,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
warnings.warn(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. ",
UserWarning,
)
if "save_config" in kwargs:
warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
......@@ -2340,6 +2348,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
load_in_8bit=load_in_8bit,
)
cls.is_loaded_in_8bit = load_in_8bit
# make sure token embedding weights are still tied if needed
model.tie_weights()
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import (
......@@ -107,6 +108,13 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_warns_save_pretrained(self):
r"""
Check that calling `save_pretrained` on `self.model_8bit` emits a `UserWarning`.

The model is saved into a temporary directory that is removed when the `with` block exits.
"""
# `assertWarns` fails the test if no `UserWarning` is raised inside the block.
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment