Commit 128bdd4c authored by thomwolf
Browse files

fix tests pt/tf

parent 28a30af6
...@@ -65,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='') ...@@ -65,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
##################### #####################
### PyTorch => TF 2.0 ### PyTorch => TF 2.0
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False): def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model """ Load pytorch checkpoints in a TF 2.0 model
""" """
try: try:
...@@ -84,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i ...@@ -84,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys) return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False): def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model """ Load pytorch checkpoints in a TF 2.0 model
""" """
pt_state_dict = pt_model.state_dict() pt_state_dict = pt_model.state_dict()
...@@ -92,7 +92,7 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS, ...@@ -92,7 +92,7 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS,
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys) return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False): def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch state_dict in a TF 2.0 model. """ Load pytorch state_dict in a TF 2.0 model.
""" """
try: try:
...@@ -104,8 +104,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I ...@@ -104,8 +104,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
raise e raise e
if tf_inputs is not None and not isinstance(tf_inputs, tf.Tensor): if tf_inputs is None:
tf_inputs = tf.constant(tf_inputs) tf_inputs = tf.constant(DUMMY_INPUTS)
if tf_inputs is not None: if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built tfo = tf_model(tf_inputs, training=False) # Make sure model is built
...@@ -176,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I ...@@ -176,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I
##################### #####################
### TF 2.0 => PyTorch ### TF 2.0 => PyTorch
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False): def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model """ Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...@@ -199,9 +199,10 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs ...@@ -199,9 +199,10 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model_class = getattr(pytorch_transformers, tf_model_class_name) tf_model_class = getattr(pytorch_transformers, tf_model_class_name)
tf_model = tf_model_class(pt_model.config) tf_model = tf_model_class(pt_model.config)
if tf_inputs is None:
tf_inputs = tf.constant(DUMMY_INPUTS)
if tf_inputs is not None: if tf_inputs is not None:
if tf_inputs is not None and not isinstance(tf_inputs, tf.Tensor):
tf_inputs = tf.constant(tf_inputs)
tfo = tf_model(tf_inputs, training=False) # Make sure model is built tfo = tf_model(tf_inputs, training=False) # Make sure model is built
tf_model.load_weights(tf_checkpoint_path, by_name=True) tf_model.load_weights(tf_checkpoint_path, by_name=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment