Unverified commit 57516c0c authored by Stas Bekman, committed by GitHub

[multiple models] skip saving/loading deterministic state_dict keys (#7878)

* make the save_load special key tests common

* handle mbart

* cleaner solution

* fix

* move test_save_load_missing_keys back into fsmt for now

* restore

* style

* add marian

* add pegasus

* blenderbot

* revert - no static embed
parent 006a1648
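
For context, a minimal sketch of the mechanism this commit builds on: entries listed in a model's keys_to_never_save are dropped from the state_dict before it is written to disk, because they can be regenerated deterministically at load time. The save_filtered helper and its signature below are hypothetical illustrations, not the library's actual save_pretrained internals; only the keys_to_never_save attribute name comes from this diff.

    import torch

    def save_filtered(model, output_model_file):
        # Assumed sketch: drop deterministic keys, then save the rest.
        state_dict = model.state_dict()
        for key in getattr(model, "keys_to_never_save", None) or []:
            # e.g. "model.encoder.embed_positions.weight"
            state_dict.pop(key, None)
        torch.save(state_dict, output_model_file)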
src/transformers/modeling_marian.py
...
@@ -47,6 +47,14 @@ class MarianMTModel(BartForConditionalGeneration):
     """
     config_class = MarianConfig
+    authorized_missing_keys = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]
+    keys_to_never_save = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]

     def adjust_logits_during_generation(self, logits, cur_len, max_length):
         logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
...
src/transformers/modeling_mbart.py
...
@@ -29,3 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
     """
     model_type = "mbart"
     config_class = MBartConfig
+    authorized_missing_keys = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]
+    keys_to_never_save = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]
src/transformers/modeling_pegasus.py
...
@@ -50,6 +50,10 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
         r"final_logits_bias",
         r"encoder\.version",
         r"decoder\.version",
-        r"model.encoder.embed_positions",
+        "model.encoder.embed_positions",
         "model.decoder.embed_positions",
     ]
+    keys_to_never_save = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]
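
The embed_positions tables being skipped here are sinusoidal position embeddings, i.e. a pure function of the table shape rather than learned weights. A rough sketch of how such a table is built (assuming an even embedding dimension; the library's SinusoidalPositionalEmbedding may lay the values out differently):

    import math
    import torch

    def sinusoidal_table(num_positions: int, dim: int) -> torch.Tensor:
        # Deterministic: every call returns the same tensor, so nothing
        # is lost by leaving it out of the checkpoint.
        position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
        table = torch.zeros(num_positions, dim)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        return table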
tests/test_modeling_common.py
...
@@ -22,6 +22,7 @@ import unittest
 from typing import List, Tuple

 from transformers import is_torch_available
+from transformers.file_utils import WEIGHTS_NAME
 from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
...
@@ -129,6 +130,27 @@ class ModelTesterMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)

+    def test_save_load_keys_to_never_save(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            keys_to_never_save = getattr(model, "keys_to_never_save", None)
+            if keys_to_never_save is None:
+                continue
+
+            # check the keys are in the original state_dict
+            for k in keys_to_never_save:
+                self.assertIn(k, model.state_dict())
+
+            # check that certain keys didn't get saved with the model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
+                state_dict_saved = torch.load(output_model_file)
+                for k in keys_to_never_save:
+                    self.assertNotIn(k, state_dict_saved)

     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
tests/test_modeling_fsmt.py
...
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 import tempfile
 import unittest
...
@@ -21,7 +20,7 @@ import timeout_decorator  # noqa
 from parameterized import parameterized

 from transformers import is_torch_available
-from transformers.file_utils import WEIGHTS_NAME, cached_property
+from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
...
@@ -203,8 +202,9 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
         )[0]
         _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)

-    def test_save_load_strict(self):
+    def test_save_load_missing_keys(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs()
         for model_class in self.all_model_classes:
             model = model_class(config)
...
@@ -213,27 +213,6 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
             model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
         self.assertEqual(info["missing_keys"], [])

-    def test_save_load_no_save_keys(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
-        for model_class in self.all_model_classes:
-            model = model_class(config)
-
-            state_dict_no_save_keys = getattr(model, "state_dict_no_save_keys", None)
-            if state_dict_no_save_keys is None:
-                continue
-
-            # check the keys are in the original state_dict
-            for k in state_dict_no_save_keys:
-                self.assertIn(k, model.state_dict())
-
-            # check that certain keys didn't get saved with the model
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
-                state_dict_saved = torch.load(output_model_file)
-                for k in state_dict_no_save_keys:
-                    self.assertNotIn(k, state_dict_saved)
-
     @unittest.skip("can't be implemented for FSMT due to dual vocab.")
     def test_resize_tokens_embeddings(self):
         pass
...
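
On the load side, the companion attribute authorized_missing_keys (visible in the Marian/MBart/Pegasus hunks above) presumably filters the skipped keys out of from_pretrained's missing-key report, so the regenerated position embeddings raise no warnings. A hedged sketch of that filtering step, with report_missing_keys as an illustrative stand-in rather than the real internals:

    import re

    def report_missing_keys(missing_keys, authorized_missing_keys):
        # Entries such as "model.encoder.embed_positions.weight" are
        # treated as patterns and matched against each missing key.
        return [
            k for k in missing_keys
            if not any(re.search(pat, k) for pat in authorized_missing_keys)
        ]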
tests/test_modeling_marian.py
...
@@ -21,6 +21,8 @@ from transformers.file_utils import cached_property
 from transformers.hf_api import HfApi
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

+from .test_modeling_common import ModelTesterMixin
+
 if is_torch_available():
     import torch
...
@@ -35,6 +37,37 @@ if is_torch_available():
     from transformers.pipelines import TranslationPipeline


+@require_torch
+class ModelTester:
+    def __init__(self, parent):
+        self.config = MarianConfig(
+            vocab_size=99,
+            d_model=24,
+            encoder_layers=2,
+            decoder_layers=2,
+            encoder_attention_heads=2,
+            decoder_attention_heads=2,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            max_position_embeddings=48,
+            add_final_layer_norm=True,
+            return_dict=True,
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        return self.config, {}
+
+
+@require_torch
+class SelectiveCommonTest(unittest.TestCase):
+    all_model_classes = (MarianMTModel,) if is_torch_available() else ()
+
+    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+
+    def setUp(self):
+        self.model_tester = ModelTester(self)
+
+
 class ModelManagementTests(unittest.TestCase):
     @slow
     def test_model_names(self):
...
tests/test_modeling_mbart.py
...
@@ -5,6 +5,7 @@ from transformers.file_utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

 from .test_modeling_bart import TOLERANCE, _long_tensor, assert_tensors_close
+from .test_modeling_common import ModelTesterMixin

 if is_torch_available():
...
@@ -23,6 +24,37 @@ EN_CODE = 250004
 RO_CODE = 250020


+@require_torch
+class ModelTester:
+    def __init__(self, parent):
+        self.config = MBartConfig(
+            vocab_size=99,
+            d_model=24,
+            encoder_layers=2,
+            decoder_layers=2,
+            encoder_attention_heads=2,
+            decoder_attention_heads=2,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            max_position_embeddings=48,
+            add_final_layer_norm=True,
+            return_dict=True,
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        return self.config, {}
+
+
+@require_torch
+class SelectiveCommonTest(unittest.TestCase):
+    all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
+
+    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+
+    def setUp(self):
+        self.model_tester = ModelTester(self)
+
+
 @require_torch
 @require_sentencepiece
 @require_tokenizers
...
tests/test_modeling_pegasus.py
...
@@ -7,17 +7,49 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
 from transformers.utils.logging import ERROR, set_verbosity

 from .test_modeling_bart import PGE_ARTICLE
+from .test_modeling_common import ModelTesterMixin
 from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest

 if is_torch_available():
-    from transformers import AutoModelForSeq2SeqLM
+    from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration

 XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """

 set_verbosity(ERROR)


+@require_torch
+class ModelTester:
+    def __init__(self, parent):
+        self.config = PegasusConfig(
+            vocab_size=99,
+            d_model=24,
+            encoder_layers=2,
+            decoder_layers=2,
+            encoder_attention_heads=2,
+            decoder_attention_heads=2,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            max_position_embeddings=48,
+            add_final_layer_norm=True,
+            return_dict=True,
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        return self.config, {}
+
+
+@require_torch
+class SelectiveCommonTest(unittest.TestCase):
+    all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
+
+    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+
+    def setUp(self):
+        self.model_tester = ModelTester(self)
+
+
 @require_torch
 @require_sentencepiece
 @require_tokenizers
...
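
With these SelectiveCommonTest shims in place, the shared check can be exercised per model without pulling in the full common suite, e.g. with an assumed pytest invocation based on the test layout shown above: pytest tests/test_modeling_marian.py::SelectiveCommonTest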