Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
324f361e
Unverified
Commit
324f361e
authored
Sep 22, 2020
by
Julien Plu
Committed by
GitHub
Sep 22, 2020
Browse files
Fix saving TF custom models (#7291)
* Fix #7277 * Apply style * Add a full training pipeline test * Apply style
parent
cd9a0585
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
14 deletions
+77
-14
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+13
-13
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+63
-0
tests/test_modeling_tf_funnel.py
tests/test_modeling_tf_funnel.py
+1
-1
No files found.
src/transformers/modeling_tf_utils.py
View file @
324f361e
...
@@ -85,20 +85,20 @@ def keras_serializable(cls):
...
@@ -85,20 +85,20 @@ def keras_serializable(cls):
@
functools
.
wraps
(
initializer
)
@
functools
.
wraps
(
initializer
)
def
wrapped_init
(
self
,
*
args
,
**
kwargs
):
def
wrapped_init
(
self
,
*
args
,
**
kwargs
):
transformers_config
=
kwargs
.
pop
(
"transformers_config"
,
None
)
config
=
args
[
0
]
if
args
and
isinstance
(
args
[
0
],
PretrainedConfig
)
else
kwargs
.
pop
(
"config"
,
None
)
config
=
args
[
0
]
if
args
and
isinstance
(
args
[
0
],
PretrainedConfig
)
else
kwargs
.
get
(
"config"
,
None
)
if
config
is
not
None
and
transformers_config
is
not
None
:
if
isinstance
(
config
,
dict
):
raise
ValueError
(
"Must pass either `config` or `transformers_config`, not both"
)
config
=
config_class
.
from_dict
(
config
)
elif
config
is
not
None
:
initializer
(
self
,
config
,
*
args
,
**
kwargs
)
# normal layer construction, call with unchanged args (config is already in there)
elif
isinstance
(
config
,
PretrainedConfig
):
if
len
(
args
)
>
0
:
initializer
(
self
,
*
args
,
**
kwargs
)
initializer
(
self
,
*
args
,
**
kwargs
)
elif
transformers_config
is
not
None
:
else
:
# Keras deserialization, convert dict to config
config
=
config_class
.
from_dict
(
transformers_config
)
initializer
(
self
,
config
,
*
args
,
**
kwargs
)
initializer
(
self
,
config
,
*
args
,
**
kwargs
)
else
:
else
:
raise
ValueError
(
"Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)"
)
raise
ValueError
(
"Must pass either `config` (PretrainedConfig) or `config` (dict)"
)
self
.
_transformers_config
=
config
self
.
_config
=
config
self
.
_kwargs
=
kwargs
self
.
_kwargs
=
kwargs
cls
.
__init__
=
wrapped_init
cls
.
__init__
=
wrapped_init
...
@@ -109,7 +109,7 @@ def keras_serializable(cls):
...
@@ -109,7 +109,7 @@ def keras_serializable(cls):
def
get_config
(
self
):
def
get_config
(
self
):
cfg
=
super
(
cls
,
self
).
get_config
()
cfg
=
super
(
cls
,
self
).
get_config
()
cfg
[
"
transformers_
config"
]
=
self
.
_
transformers_
config
.
to_dict
()
cfg
[
"config"
]
=
self
.
_config
.
to_dict
()
cfg
.
update
(
self
.
_kwargs
)
cfg
.
update
(
self
.
_kwargs
)
return
cfg
return
cfg
...
...
tests/test_modeling_tf_common.py
View file @
324f361e
...
@@ -354,6 +354,69 @@ class TFModelTesterMixin:
...
@@ -354,6 +354,69 @@ class TFModelTesterMixin:
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
-
pto
))
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
-
pto
))
self
.
assertLessEqual
(
max_diff
,
4e-2
)
self
.
assertLessEqual
(
max_diff
,
4e-2
)
def
test_train_pipeline_custom_model
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
tf_main_layer_classes
=
set
(
module_member
for
model_class
in
self
.
all_model_classes
for
module
in
(
import_module
(
model_class
.
__module__
),)
for
module_member_name
in
dir
(
module
)
if
module_member_name
.
endswith
(
"MainLayer"
)
for
module_member
in
(
getattr
(
module
,
module_member_name
),)
if
isinstance
(
module_member
,
type
)
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
and
getattr
(
module_member
,
"_keras_serializable"
,
False
)
)
for
main_layer_class
in
tf_main_layer_classes
:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if
"T5"
in
main_layer_class
.
__name__
:
# Take the same values than in TFT5ModelTester for this shared layer
shared
=
TFSharedEmbeddings
(
self
.
model_tester
.
vocab_size
,
self
.
model_tester
.
hidden_size
,
name
=
"shared"
)
config
.
use_cache
=
False
main_layer
=
main_layer_class
(
config
,
embed_tokens
=
shared
)
del
inputs_dict
[
"use_cache"
]
else
:
main_layer
=
main_layer_class
(
config
)
symbolic_inputs
=
{
name
:
tf
.
keras
.
Input
(
tensor
.
shape
[
1
:],
dtype
=
tensor
.
dtype
)
for
name
,
tensor
in
inputs_dict
.
items
()
}
if
hasattr
(
self
.
model_tester
,
"num_labels"
):
num_labels
=
self
.
model_tester
.
num_labels
else
:
num_labels
=
2
X
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
(
inputs_dict
,
np
.
random
.
randint
(
0
,
num_labels
,
(
self
.
model_tester
.
batch_size
,
1
)))
).
batch
(
1
)
hidden_states
=
main_layer
(
symbolic_inputs
)[
0
]
outputs
=
tf
.
keras
.
layers
.
Dense
(
num_labels
,
activation
=
"softmax"
,
name
=
"outputs"
)(
hidden_states
)
model
=
tf
.
keras
.
models
.
Model
(
inputs
=
symbolic_inputs
,
outputs
=
[
outputs
])
model
.
compile
(
loss
=
"binary_crossentropy"
,
optimizer
=
"adam"
,
metrics
=
[
"acc"
])
model
.
fit
(
X
,
epochs
=
1
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
filepath
=
os
.
path
.
join
(
tmpdirname
,
"keras_model.h5"
)
model
.
save
(
filepath
)
if
"T5"
in
main_layer_class
.
__name__
:
model
=
tf
.
keras
.
models
.
load_model
(
filepath
,
custom_objects
=
{
main_layer_class
.
__name__
:
main_layer_class
,
"TFSharedEmbeddings"
:
TFSharedEmbeddings
,
},
)
else
:
model
=
tf
.
keras
.
models
.
load_model
(
filepath
,
custom_objects
=
{
main_layer_class
.
__name__
:
main_layer_class
}
)
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
model
(
inputs_dict
)
def
test_compile_tf_model
(
self
):
def
test_compile_tf_model
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_funnel.py
View file @
324f361e
...
@@ -327,7 +327,7 @@ class TFFunnelModelTester:
...
@@ -327,7 +327,7 @@ class TFFunnelModelTester:
@
require_tf
@
require_tf
class
FunnelModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
class
TF
FunnelModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
all_model_classes
=
(
(
(
TFFunnelModel
,
TFFunnelModel
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment