Unverified Commit 234cfefb authored by Li-Huai (Allan) Lin, committed by GitHub

Fix ignore_mismatched_sizes (#14085)

* Fix

* Style

* Name

* Fix tests

* Style

* Remove embed sizes checking

* Disable some tests

* Fix

* Apply suggestion
parent e03544a1
@@ -1512,10 +1512,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if ignore_mismatched_sizes:
             for checkpoint_key in loaded_keys:
                 model_key = checkpoint_key
-                if remove_prefix and checkpoint_key.startswith(prefix):
-                    model_key = ".".join(checkpoint_key.split(".")[1:])
-                elif add_prefix:
-                    model_key = f"{prefix}.{checkpoint_key}"
+                if remove_prefix:
+                    model_key = f"{prefix}.{checkpoint_key}"
+                elif add_prefix:
+                    model_key = ".".join(checkpoint_key.split(".")[1:])

                 if (
                     model_key in model_state_dict
...
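Reading the hunk, the fix swaps the two key transformations: `remove_prefix` covers the case where the checkpoint keys lack the base-model prefix while the model's own keys carry it, so the prefix must be added to the checkpoint key before looking it up in `model_state_dict`; `add_prefix` is the mirror case. The old guard `checkpoint_key.startswith(prefix)` could never be true in the `remove_prefix` case, so mismatches went undetected. A minimal, self-contained sketch of the corrected mapping (the helper name is ours, not the library's):

# Sketch of the corrected checkpoint-key -> model-key mapping (illustrative
# helper, not transformers code). `prefix` is the base-model prefix, e.g. "bert".
def to_model_key(checkpoint_key, prefix, remove_prefix, add_prefix):
    if remove_prefix:
        # Checkpoint keys lack the prefix but the model's keys have it: add it.
        return f"{prefix}.{checkpoint_key}"
    elif add_prefix:
        # Checkpoint keys carry the prefix but the model's keys don't: strip it.
        return ".".join(checkpoint_key.split(".")[1:])
    return checkpoint_key

# Base-model checkpoint loaded into a task model:
assert to_model_key("pooler.dense.weight", "bert", True, False) == "bert.pooler.dense.weight"
# Task-model checkpoint loaded into a base model:
assert to_model_key("bert.pooler.dense.weight", "bert", False, True) == "pooler.dense.weight"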
@@ -220,6 +220,7 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase):
     )
     test_torchscript = False
+    test_mismatched_shapes = False
     test_resize_embeddings = False
     test_pruning = False
...
@@ -98,6 +98,7 @@ class ModelTesterMixin:
     test_resize_embeddings = True
     test_resize_position_embeddings = False
     test_head_masking = True
+    test_mismatched_shapes = True
     test_missing_keys = True
     test_model_parallel = False
     is_encoder_decoder = False
@@ -1638,6 +1639,8 @@ class ModelTesterMixin:
         loss.backward()

     def test_load_with_mismatched_shapes(self):
+        if not self.test_mismatched_shapes:
+            return
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
@@ -1650,22 +1653,35 @@ class ModelTesterMixin:
                     model.save_pretrained(tmp_dir)

                     # Fails when we don't set ignore_mismatched_sizes=True
-                    with self.assertRaises(RuntimeError) as e:
-                        print(type(e))
+                    with self.assertRaises(RuntimeError):
                         new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+                    with self.assertRaises(RuntimeError):
+                        new_model_without_prefix = AutoModel.from_pretrained(tmp_dir, vocab_size=10)

                     logger = logging.get_logger("transformers.modeling_utils")
                     with CaptureLogger(logger) as cl:
                         new_model = AutoModelForSequenceClassification.from_pretrained(
                             tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                         )
                     self.assertIn("the shapes did not match", cl.out)
                     new_model.to(torch_device)
                     inputs = self._prepare_for_class(inputs_dict, model_class)
                     logits = new_model(**inputs).logits
                     self.assertEqual(logits.shape[1], 42)

+                    with CaptureLogger(logger) as cl:
+                        new_model_without_prefix = AutoModel.from_pretrained(
+                            tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+                    input_ids = ids_tensor((2, 8), 10)
+                    new_model_without_prefix.to(torch_device)
+                    if self.is_encoder_decoder:
+                        new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
+                    else:
+                        new_model_without_prefix(input_ids)


 global_rng = random.Random()
...
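For reference, the user-facing behavior this PyTorch test pins down, as a hedged usage sketch (the checkpoint name is a placeholder, not from the diff):

from transformers import AutoModelForSequenceClassification

# "some-org/finetuned-2-label-model" is hypothetical; substitute any
# sequence-classification checkpoint trained with a different num_labels.
# Without ignore_mismatched_sizes=True this raises a RuntimeError because
# the saved classifier weights cannot be copied into a 42-label head.
model = AutoModelForSequenceClassification.from_pretrained(
    "some-org/finetuned-2-label-model",
    num_labels=42,
    ignore_mismatched_sizes=True,
)
# The mismatched weights are freshly initialized and a warning containing
# "the shapes did not match" is logged, as the test asserts above.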
@@ -149,6 +149,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
     )
     test_attn_probs = False
+    test_mismatched_shapes = False

     def setUp(self):
         self.model_tester = FlaxBigBirdModelTester(self)
...
@@ -49,6 +49,7 @@ if is_flax_available():
         FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         FLAX_MODEL_MAPPING,
+        FlaxAutoModel,
         FlaxAutoModelForSequenceClassification,
         FlaxBertModel,
     )
@@ -116,6 +117,7 @@ class FlaxModelTesterMixin:
     model_tester = None
     all_model_classes = ()
+    test_mismatched_shapes = True
     is_encoder_decoder = False

     def _prepare_for_class(self, inputs_dict, model_class):
@@ -579,6 +581,8 @@ class FlaxModelTesterMixin:
     )

     def test_load_with_mismatched_shapes(self):
+        if not self.test_mismatched_shapes:
+            return
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
@@ -593,6 +597,8 @@ class FlaxModelTesterMixin:
                     # Fails when we don't set ignore_mismatched_sizes=True
                     with self.assertRaises(ValueError):
                         new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+                    with self.assertRaises(ValueError):
+                        new_model_without_prefix = FlaxAutoModel.from_pretrained(tmp_dir, vocab_size=10)

                     logger = logging.get_logger("transformers.modeling_flax_utils")
                     with CaptureLogger(logger) as cl:
@@ -604,6 +610,17 @@ class FlaxModelTesterMixin:
                     logits = new_model(**inputs_dict)["logits"]
                     self.assertEqual(logits.shape[1], 42)

+                    with CaptureLogger(logger) as cl:
+                        new_model_without_prefix = FlaxAutoModel.from_pretrained(
+                            tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+                    input_ids = ids_tensor((2, 8), 10)
+                    if self.is_encoder_decoder:
+                        new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
+                    else:
+                        new_model_without_prefix(input_ids)


 @require_flax
 @is_staging_test
...
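Both the PyTorch and Flax tests build their dummy batch with `ids_tensor((2, 8), 10)`, i.e. random token ids drawn inside the shrunken 10-entry vocabulary so the forward pass cannot index past the resized embedding. A framework-neutral sketch of that contract (the real helpers return torch/numpy tensors):

import random

def ids_list(shape, vocab_size, rng=None):
    # Illustrative stand-in for the tests' ids_tensor: a nested list of random
    # token ids in [0, vocab_size); the real helpers return torch/np tensors.
    rng = rng or random.Random(0)
    rows, cols = shape
    return [[rng.randrange(vocab_size) for _ in range(cols)] for _ in range(rows)]

batch = ids_list((2, 8), 10)  # mirrors ids_tensor((2, 8), 10) in the tests
assert all(0 <= token < 10 for row in batch for token in row)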
@@ -260,6 +260,7 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False
+    test_mismatched_shapes = False

     all_model_classes = (
         (
...
@@ -59,6 +59,7 @@ if is_tf_available():
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        TFAutoModel,
         TFAutoModelForSequenceClassification,
         TFBertModel,
         TFSharedEmbeddings,
@@ -104,6 +105,7 @@ class TFModelTesterMixin:
     model_tester = None
     all_model_classes = ()
     all_generative_model_classes = ()
+    test_mismatched_shapes = True
     test_resize_embeddings = True
     test_head_masking = True
     is_encoder_decoder = False
@@ -1312,6 +1314,8 @@ class TFModelTesterMixin:
         self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)

     def test_load_with_mismatched_shapes(self):
+        if not self.test_mismatched_shapes:
+            return
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
@@ -1328,6 +1332,8 @@ class TFModelTesterMixin:
                     # Fails when we don't set ignore_mismatched_sizes=True
                     with self.assertRaises(ValueError):
                         new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+                    with self.assertRaises(ValueError):
+                        new_model_without_prefix = TFAutoModel.from_pretrained(tmp_dir, vocab_size=10)

                     logger = logging.get_logger("transformers.modeling_tf_utils")
                     with CaptureLogger(logger) as cl:
@@ -1339,6 +1345,20 @@ class TFModelTesterMixin:
                     logits = new_model(**inputs).logits
                     self.assertEqual(logits.shape[1], 42)

+                    with CaptureLogger(logger) as cl:
+                        new_model_without_prefix = TFAutoModel.from_pretrained(
+                            tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+
+                    # Although TF models always have a prefix pointing to `MainLayer`,
+                    # we still add this "without prefix" test to keep consistency between TF and PT tests.
+                    input_ids = ids_tensor((2, 8), 10)
+                    if self.is_encoder_decoder:
+                        new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
+                    else:
+                        new_model_without_prefix(input_ids)

     def _generate_random_bad_tokens(self, num_bad_tokens, model):
         # special tokens cannot be bad tokens
         special_tokens = []
...
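The TF comment above notes that TF models always carry a `MainLayer` prefix; whether a prefix transformation is needed at all is decided at load time by comparing the checkpoint's and the model's key sets. A hedged sketch of that decision, mirroring our reading of the logic that surrounds the first hunk (simplified, not a verbatim copy of modeling_utils):

# Derive the remove_prefix/add_prefix flags used in the first hunk.
def prefix_flags(loaded_keys, expected_keys, prefix):
    has_prefix = any(k.startswith(prefix) for k in loaded_keys)        # checkpoint side
    expects_prefix = any(k.startswith(prefix) for k in expected_keys)  # model side
    remove_prefix = not has_prefix and expects_prefix
    add_prefix = has_prefix and not expects_prefix
    return remove_prefix, add_prefix

# Base-model checkpoint into a model whose keys are prefixed with "bert":
assert prefix_flags(["pooler.dense.weight"], ["bert.pooler.dense.weight"], "bert") == (True, False)
# Prefixed checkpoint into a bare base model:
assert prefix_flags(["bert.pooler.dense.weight"], ["pooler.dense.weight"], "bert") == (False, True)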
@@ -165,6 +165,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
     test_resize_embeddings = False
     test_head_masking = False
     test_onnx = False
+    test_mismatched_shapes = False

     def setUp(self):
         self.model_tester = TFTransfoXLModelTester(self)
...
@@ -180,6 +180,7 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = True
+    test_mismatched_shapes = False

     def check_cutoffs_and_n_token(
         self, copied_cutoffs, layer, model_embed, model, model_class, resized_value, vocab_size
...