Unverified Commit c34a525d authored by Nicolas Patry, committed by GitHub
Browse files

Proposed fix for TF example now running on safetensors. (#23208)



* Proposed fix for TF example now running on safetensors.

* Adding more warnings and returning keys.

* Trigger CI

* Trigger CI

---------
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
parent b4d4d6fe
...@@ -297,7 +297,6 @@ class ExamplesTests(TestCasePlus): ...@@ -297,7 +297,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir) result = get_results(tmp_dir)
self.assertGreaterEqual(result["bleu"], 30) self.assertGreaterEqual(result["bleu"], 30)
@skip("Fix me Matt")
def test_run_image_classification(self): def test_run_image_classification(self):
tmp_dir = self.get_auto_remove_tmp_dir() tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f""" testargs = f"""
......
...@@ -246,6 +246,7 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -246,6 +246,7 @@ def load_pytorch_state_dict_in_tf2_model(
output_loading_info=False, output_loading_info=False,
_prefix=None, _prefix=None,
tf_to_pt_weight_rename=None, tf_to_pt_weight_rename=None,
ignore_mismatched_sizes=False,
): ):
"""Load a pytorch state_dict in a TF 2.0 model.""" """Load a pytorch state_dict in a TF 2.0 model."""
import tensorflow as tf import tensorflow as tf
...@@ -297,6 +298,7 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -297,6 +298,7 @@ def load_pytorch_state_dict_in_tf2_model(
weight_value_tuples = [] weight_value_tuples = []
all_pytorch_weights = set(pt_state_dict.keys()) all_pytorch_weights = set(pt_state_dict.keys())
missing_keys = [] missing_keys = []
mismatched_keys = []
for symbolic_weight in symbolic_weights: for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name( name, transpose = convert_tf_weight_name_to_pt_weight_name(
...@@ -319,7 +321,18 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -319,7 +321,18 @@ def load_pytorch_state_dict_in_tf2_model(
continue continue
raise AttributeError(f"{name} not found in PyTorch model") raise AttributeError(f"{name} not found in PyTorch model")
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape) try:
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
except tf.errors.InvalidArgumentError as e:
if not ignore_mismatched_sizes:
error_msg = str(e)
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise tf.errors.InvalidArgumentError(error_msg)
else:
mismatched_keys.append((name, pt_state_dict[name].shape, symbolic_weight.shape))
continue
tf_loaded_numel += tensor_size(array) tf_loaded_numel += tensor_size(array)
...@@ -367,8 +380,26 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -367,8 +380,26 @@ def load_pytorch_state_dict_in_tf2_model(
f"you can already use {tf_model.__class__.__name__} for predictions without further training." f"you can already use {tf_model.__class__.__name__} for predictions without further training."
) )
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
f" are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
if output_loading_info: if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
}
return tf_model, loading_info return tf_model, loading_info
return tf_model return tf_model
......
...@@ -2820,6 +2820,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2820,6 +2820,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
allow_missing_keys=True, allow_missing_keys=True,
output_loading_info=output_loading_info, output_loading_info=output_loading_info,
_prefix=load_weight_prefix, _prefix=load_weight_prefix,
ignore_mismatched_sizes=ignore_mismatched_sizes,
) )
# 'by_name' allow us to do transfer learning by skipping/adding layers # 'by_name' allow us to do transfer learning by skipping/adding layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment