Commit 124409d0 authored by Filip Povolny's avatar Filip Povolny
Browse files

Make dummy inputs a property of TFPreTrainedModel.

parent 8df7dfd2
...@@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_archive_map = {} pretrained_model_archive_map = {}
base_model_prefix = "" base_model_prefix = ""
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return tf.constant(DUMMY_INPUTS)
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs) super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
...@@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file) return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file) assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
# 'by_name' allow us to do transfer learning by skipping/adding layers # 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model.load_weights(resolved_archive_file, by_name=True) model.load_weights(resolved_archive_file, by_name=True)
ret = model(dummy_inputs, training=False) # Make sure restore ops are run ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment