Unverified Commit db2644b9 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix PyTorch/TF Auto tests (#17895)



* add loading_info
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent f717d47f
...@@ -103,7 +103,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", ...@@ -103,7 +103,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
##################### #####################
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False): def load_pytorch_checkpoint_in_tf2_model(
tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load pytorch checkpoints in a TF 2.0 model""" """Load pytorch checkpoints in a TF 2.0 model"""
try: try:
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
...@@ -122,7 +124,11 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i ...@@ -122,7 +124,11 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters") logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
return load_pytorch_weights_in_tf2_model( return load_pytorch_weights_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys tf_model,
pt_state_dict,
tf_inputs=tf_inputs,
allow_missing_keys=allow_missing_keys,
output_loading_info=output_loading_info,
) )
...@@ -135,7 +141,9 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi ...@@ -135,7 +141,9 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
) )
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False): def load_pytorch_weights_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load pytorch state_dict in a TF 2.0 model.""" """Load pytorch state_dict in a TF 2.0 model."""
try: try:
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
...@@ -281,6 +289,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a ...@@ -281,6 +289,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
f"you can already use {tf_model.__class__.__name__} for predictions without further training." f"you can already use {tf_model.__class__.__name__} for predictions without further training."
) )
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
return tf_model, loading_info
return tf_model return tf_model
...@@ -289,7 +301,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a ...@@ -289,7 +301,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
##################### #####################
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False): def load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
""" """
Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see
https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...@@ -323,17 +337,21 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs ...@@ -323,17 +337,21 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
load_tf_weights(tf_model, tf_checkpoint_path) load_tf_weights(tf_model, tf_checkpoint_path)
return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys) return load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
)
def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False): def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
"""Load TF 2.0 model in a pytorch model""" """Load TF 2.0 model in a pytorch model"""
weights = tf_model.weights weights = tf_model.weights
return load_tf2_weights_in_pytorch_model(pt_model, weights, allow_missing_keys=allow_missing_keys) return load_tf2_weights_in_pytorch_model(
pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
)
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False): def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
"""Load TF2.0 symbolic weights in a PyTorch model""" """Load TF2.0 symbolic weights in a PyTorch model"""
try: try:
import tensorflow as tf # noqa: F401 import tensorflow as tf # noqa: F401
...@@ -460,4 +478,8 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F ...@@ -460,4 +478,8 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}") logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
return pt_model, loading_info
return pt_model return pt_model
...@@ -2316,7 +2316,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2316,7 +2316,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True) return load_pytorch_checkpoint_in_tf2_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
)
# we might need to extend the variable scope for composite models # we might need to extend the variable scope for composite models
if load_weight_prefix is not None: if load_weight_prefix is not None:
......
...@@ -1831,6 +1831,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1831,6 +1831,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_sharded = False is_sharded = False
sharded_metadata = None sharded_metadata = None
# Load model # Load model
loading_info = None
if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
...@@ -2086,7 +2088,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2086,7 +2088,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
try: try:
from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) model, loading_info = load_tf2_checkpoint_in_pytorch_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True
)
except ImportError: except ImportError:
logger.error( logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
...@@ -2139,12 +2143,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2139,12 +2143,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dispatch_model(model, device_map=device_map, offload_dir=offload_folder) dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
if output_loading_info: if output_loading_info:
loading_info = { if loading_info is None:
"missing_keys": missing_keys, loading_info = {
"unexpected_keys": unexpected_keys, "missing_keys": missing_keys,
"mismatched_keys": mismatched_keys, "unexpected_keys": unexpected_keys,
"error_msgs": error_msgs, "mismatched_keys": mismatched_keys,
} "error_msgs": error_msgs,
}
return model, loading_info return model, loading_info
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment