Commit 8df7dfd2 authored by Filip Povolny's avatar Filip Povolny
Browse files

Make dummy inputs a local variable in TFPreTrainedModel.

parent f1e4db2a
...@@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model):
config_class = None config_class = None
pretrained_model_archive_map = {} pretrained_model_archive_map = {}
base_model_prefix = "" base_model_prefix = ""
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs) super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
...@@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file) return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file) assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
# 'by_name' allow us to do transfer learning by skipping/adding layers # 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model.load_weights(resolved_archive_file, by_name=True) model.load_weights(resolved_archive_file, by_name=True)
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run ret = model(dummy_inputs, training=False) # Make sure restore ops are run
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment