Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8df7dfd2
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "24e67fbf75a4dc62d4ad2f8b99e84a3ecbf35a6e"
Commit
8df7dfd2
authored
Nov 05, 2019
by
Filip Povolny
Browse files
Make dummy inputs a local variable in TFPreTrainedModel.
parent
f1e4db2a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+3
-3
No files found.
transformers/modeling_tf_utils.py
View file @
8df7dfd2
...
@@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model):
config_class
=
None
config_class
=
None
pretrained_model_archive_map
=
{}
pretrained_model_archive_map
=
{}
base_model_prefix
=
""
base_model_prefix
=
""
dummy_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
# dummy inputs to build the network
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
...
@@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
dummy_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
# dummy inputs to build the network
ret
=
model
(
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# Make sure restore ops are run
ret
=
model
(
dummy_inputs
,
training
=
False
)
# Make sure restore ops are run
return
model
return
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment