Commit 3cb51299 authored by thomwolf's avatar thomwolf Committed by Lysandre Debut
Browse files

Fix #2109

parent 18a879f4
...@@ -143,7 +143,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a ...@@ -143,7 +143,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove) name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove)
# Find associated numpy array in pytorch model state dict # Find associated numpy array in pytorch model state dict
assert name in pt_state_dict, "{} not found in PyTorch model".format(name) if name not in pt_state_dict:
if allow_missing_keys:
continue
raise AttributeError("{} not found in PyTorch model".format(name))
array = pt_state_dict[name].numpy() array = pt_state_dict[name].numpy()
if transpose: if transpose:
...@@ -250,6 +254,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F ...@@ -250,6 +254,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights = set(list(tf_weights_map.keys())) all_tf_weights = set(list(tf_weights_map.keys()))
loaded_pt_weights_data_ptr = {} loaded_pt_weights_data_ptr = {}
missing_keys_pt = []
for pt_weight_name, pt_weight in current_pt_params_dict.items(): for pt_weight_name, pt_weight in current_pt_params_dict.items():
# Handle PyTorch shared weight ()not duplicated in TF 2.0 # Handle PyTorch shared weight ()not duplicated in TF 2.0
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr: if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
...@@ -258,7 +263,10 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F ...@@ -258,7 +263,10 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
# Find associated numpy array in pytorch model state dict # Find associated numpy array in pytorch model state dict
if pt_weight_name not in tf_weights_map: if pt_weight_name not in tf_weights_map:
raise ValueError("{} not found in TF 2.0 model".format(pt_weight_name)) if allow_missing_keys:
missing_keys_pt.append(pt_weight_name)
continue
raise AttributeError("{} not found in TF 2.0 model".format(pt_weight_name))
array, transpose = tf_weights_map[pt_weight_name] array, transpose = tf_weights_map[pt_weight_name]
...@@ -283,6 +291,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F ...@@ -283,6 +291,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights.discard(pt_weight_name) all_tf_weights.discard(pt_weight_name)
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False) missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
missing_keys += missing_keys_pt
if len(missing_keys) > 0: if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from TF 2.0 model: {}".format( logger.info("Weights of {} not initialized from TF 2.0 model: {}".format(
......
...@@ -297,7 +297,7 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -297,7 +297,7 @@ class TFPreTrainedModel(tf.keras.Model):
if from_pt: if from_pt:
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file) return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment