Commit 29bb3e4e authored by thomwolf's avatar thomwolf
Browse files

double loading ok

parent f5397ffc
...@@ -25,6 +25,7 @@ import numpy ...@@ -25,6 +25,7 @@ import numpy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''): def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
""" Convert a TF 2.0 model variable name in a pytorch model weight name. """ Convert a TF 2.0 model variable name in a pytorch model weight name.
...@@ -64,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='') ...@@ -64,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
##################### #####################
### PyTorch => TF 2.0 ### PyTorch => TF 2.0
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False): def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model """ Load pytorch checkpoints in a TF 2.0 model
""" """
try: try:
...@@ -83,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i ...@@ -83,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys) return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False): def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model """ Load pytorch checkpoints in a TF 2.0 model
""" """
pt_state_dict = pt_model.state_dict() pt_state_dict = pt_model.state_dict()
...@@ -91,17 +92,21 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi ...@@ -91,17 +92,21 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys) return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False): def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
""" Load pytorch state_dict in a TF 2.0 model. """ Load pytorch state_dict in a TF 2.0 model.
""" """
try: try:
import torch import torch
import tensorflow as tf
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
except ImportError as e: except ImportError as e:
logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
raise e raise e
if tf_inputs is not None and not isinstance(tf_inputs, tf.Tensor):
tf_inputs = tf.constant(tf_inputs)
if tf_inputs is not None: if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built tfo = tf_model(tf_inputs, training=False) # Make sure model is built
...@@ -171,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a ...@@ -171,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
##################### #####################
### TF 2.0 => PyTorch ### TF 2.0 => PyTorch
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False): def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model """ Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...@@ -184,15 +189,19 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs ...@@ -184,15 +189,19 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
raise e raise e
import pytorch_transformers
tf_path = os.path.abspath(tf_checkpoint_path) tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path)) logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
# Instantiate and load the associated TF 2.0 model # Instantiate and load the associated TF 2.0 model
tf_model_class_name = "TF" + model_class.__name__ # Add "TF" at the beggining tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beggining
tf_model_class = getattr(pytorch_transformers, tf_model_class_name) tf_model_class = getattr(pytorch_transformers, tf_model_class_name)
tf_model = tf_model_class(pt_model.config) tf_model = tf_model_class(pt_model.config)
if tf_inputs is not None: if tf_inputs is not None:
if tf_inputs is not None and not isinstance(tf_inputs, tf.Tensor):
tf_inputs = tf.constant(tf_inputs)
tfo = tf_model(tf_inputs, training=False) # Make sure model is built tfo = tf_model(tf_inputs, training=False) # Make sure model is built
tf_model.load_weights(tf_checkpoint_path, by_name=True) tf_model.load_weights(tf_checkpoint_path, by_name=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment