"docs/vscode:/vscode.git/clone" did not exist on "6c9010ef63da1570e5a651a05bb00855b7075514"
Unverified Commit d4c834d2 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix from_pretrained with corrupted state_dict (#12939)

* Fix from_pretrained with corrupted state_dict

* Adapt test

* Use better checkpoint

* Style

* Clean up
parent a28da4c4
...@@ -1409,6 +1409,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1409,6 +1409,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
add_prefix = has_prefix_module and not expects_prefix_module add_prefix = has_prefix_module and not expects_prefix_module
if remove_prefix: if remove_prefix:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys] expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
elif add_prefix: elif add_prefix:
expected_keys = [".".join([prefix, s]) for s in expected_keys] expected_keys = [".".join([prefix, s]) for s in expected_keys]
...@@ -1490,6 +1491,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1490,6 +1491,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
start_prefix = cls.base_model_prefix + "." start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not has_prefix_module: if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix) model_to_load = getattr(model, cls.base_model_prefix)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError(
"The state dictionary of the model you are training to load is corrupted. Are you sure it was "
"properly saved?"
)
load(model_to_load, prefix=start_prefix) load(model_to_load, prefix=start_prefix)
......
...@@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase): ...@@ -49,7 +49,7 @@ class BenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.memory_inference_result) self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_no_configs_only_pretrain(self): def test_inference_no_configs_only_pretrain(self):
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" MODEL_ID = "sgugger/tiny-distilbert-classification"
benchmark_args = PyTorchBenchmarkArguments( benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], models=[MODEL_ID],
training=False, training=False,
......
...@@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase): ...@@ -52,7 +52,7 @@ class TFBenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.memory_inference_result) self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_no_configs_only_pretrain(self): def test_inference_no_configs_only_pretrain(self):
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english" MODEL_ID = "sgugger/tiny-distilbert-classification"
benchmark_args = TensorFlowBenchmarkArguments( benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID], models=[MODEL_ID],
training=False, training=False,
......
...@@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin ...@@ -22,9 +22,7 @@ from .test_pipelines_common import CustomInputPipelineCommonMixin
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "zero-shot-classification" pipeline_task = "zero-shot-classification"
small_models = [ small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
] # Models tested without the @slow decorator
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
valid_inputs = [ valid_inputs = [
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"}, {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment