Commit 128bdd4c authored by thomwolf's avatar thomwolf
Browse files

fix tests pt/tf

parent 28a30af6
......@@ -65,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
#####################
### PyTorch => TF 2.0
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model
"""
try:
......@@ -84,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model
"""
pt_state_dict = pt_model.state_dict()
......@@ -92,7 +92,7 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS,
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch state_dict in a TF 2.0 model.
"""
try:
......@@ -104,8 +104,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
raise e
if tf_inputs is not None and not isinstance(tf_inputs, tf.Tensor):
tf_inputs = tf.constant(tf_inputs)
if tf_inputs is None:
tf_inputs = tf.constant(DUMMY_INPUTS)
if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
......@@ -176,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I
#####################
### TF 2.0 => PyTorch
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=DUMMY_INPUTS, allow_missing_keys=False):
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
......@@ -199,9 +199,10 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model_class = getattr(pytorch_transformers, tf_model_class_name)
tf_model = tf_model_class(pt_model.config)
if tf_inputs is None:
tf_inputs = tf.constant(DUMMY_INPUTS)
if tf_inputs is not None:
if tf_inputs is not None and not isinstance(tf_inputs, tf.Tensor):
tf_inputs = tf.constant(tf_inputs)
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
tf_model.load_weights(tf_checkpoint_path, by_name=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment