Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
124409d0
Commit
124409d0
authored
Nov 05, 2019
by
Filip Povolny
Browse files
Make dummy inputs a property of TFPreTrainedModel.
parent
8df7dfd2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
3 deletions
+11
-3
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+11
-3
No files found.
transformers/modeling_tf_utils.py
View file @
124409d0
...
@@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_archive_map
=
{}
pretrained_model_archive_map
=
{}
base_model_prefix
=
""
base_model_prefix
=
""
@
property
def
dummy_inputs
(
self
):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return
tf
.
constant
(
DUMMY_INPUTS
)
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
if
not
isinstance
(
config
,
PretrainedConfig
):
if
not
isinstance
(
config
,
PretrainedConfig
):
...
@@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
dummy_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
# dummy inputs to build the network
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
ret
=
model
(
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
ret
=
model
(
dummy_inputs
,
training
=
False
)
# Make sure restore ops are run
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# Make sure restore ops are run
return
model
return
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment