Unverified commit db2644b9, authored by Yih-Dar and committed by GitHub
Browse files

Fix PyTorch/TF Auto tests (#17895)



* add loading_info
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent f717d47f
......@@ -103,7 +103,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
#####################
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
def load_pytorch_checkpoint_in_tf2_model(
tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load pytorch checkpoints in a TF 2.0 model"""
try:
import tensorflow as tf # noqa: F401
......@@ -122,7 +124,11 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
return load_pytorch_weights_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys
tf_model,
pt_state_dict,
tf_inputs=tf_inputs,
allow_missing_keys=allow_missing_keys,
output_loading_info=output_loading_info,
)
......@@ -135,7 +141,9 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
)
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
def load_pytorch_weights_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load pytorch state_dict in a TF 2.0 model."""
try:
import tensorflow as tf # noqa: F401
......@@ -281,6 +289,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
)
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
return tf_model, loading_info
return tf_model
......@@ -289,7 +301,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
#####################
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
def load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""
Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see
https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
......@@ -323,17 +337,21 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
load_tf_weights(tf_model, tf_checkpoint_path)
return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys)
return load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
)
def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False):
def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
"""Load TF 2.0 model in a pytorch model"""
weights = tf_model.weights
return load_tf2_weights_in_pytorch_model(pt_model, weights, allow_missing_keys=allow_missing_keys)
return load_tf2_weights_in_pytorch_model(
pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
)
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
"""Load TF2.0 symbolic weights in a PyTorch model"""
try:
import tensorflow as tf # noqa: F401
......@@ -460,4 +478,8 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
return pt_model, loading_info
return pt_model
......@@ -2316,7 +2316,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
# Load from a PyTorch checkpoint
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
return load_pytorch_checkpoint_in_tf2_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
)
# we might need to extend the variable scope for composite models
if load_weight_prefix is not None:
......
......@@ -1831,6 +1831,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_sharded = False
sharded_metadata = None
# Load model
loading_info = None
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
......@@ -2086,7 +2088,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
try:
from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
model, loading_info = load_tf2_checkpoint_in_pytorch_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True
)
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
......@@ -2139,6 +2143,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
if output_loading_info:
if loading_info is None:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment