chenpangpang / transformers / Commits / 9499a377

Commit 9499a377 (unverified), authored Mar 06, 2020 by Thomas Wolf, committed via GitHub on Mar 06, 2020.
Parents: c8035e11, 4c91a3af

Merge pull request #3103 from gthb/keras-serialization

    Support keras JSON/HDF5 serialization of main layers

Showing 8 changed files with 134 additions and 14 deletions (+134, -14):
  src/transformers/modeling_tf_albert.py       +4   -1
  src/transformers/modeling_tf_bert.py         +4   -1
  src/transformers/modeling_tf_ctrl.py         +4   -1
  src/transformers/modeling_tf_gpt2.py         +4   -0
  src/transformers/modeling_tf_transfo_xl.py   +4   -1
  src/transformers/modeling_tf_utils.py        +59  -2
  src/transformers/modeling_tf_xlnet.py        +11  -1
  tests/test_modeling_tf_common.py             +44  -7
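A minimal sketch of what the change enables (illustrative, not part of the diff): once a TF*MainLayer class is decorated with @keras_serializable, a plain tf.keras.Model built around it can be saved to and restored from HDF5 with standard Keras calls. The config, input shape, and file name below are arbitrary choices, and custom_objects is passed explicitly even though the decorator also registers the class with Keras when register_keras_serializable is available.

    import tensorflow as tf
    from transformers import BertConfig
    from transformers.modeling_tf_bert import TFBertMainLayer

    config = BertConfig()  # arbitrary default config, purely for illustration
    input_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_ids")
    model = tf.keras.Model(inputs=input_ids, outputs=TFBertMainLayer(config)(input_ids))

    model.save("bert_main_layer.h5")  # HDF5 serialization of architecture + weights
    restored = tf.keras.models.load_model(
        "bert_main_layer.h5", custom_objects={"TFBertMainLayer": TFBertMainLayer}
    )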
src/transformers/modeling_tf_albert.py  (+4, -1)

@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list


 logger = logging.getLogger(__name__)
@@ -478,7 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         return hidden_states


+@keras_serializable
 class TFAlbertMainLayer(tf.keras.layers.Layer):
+    config_class = AlbertConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
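The ALBERT change above follows the same pattern applied to every model file below: import keras_serializable, decorate the main layer, and declare its config_class. As a hedged sketch of the JSON half of the PR title (illustrative code, not part of the diff; the config and sequence length are arbitrary), the decorated layer should also survive a to_json/model_from_json round trip, with weights copied separately since the JSON config carries only the architecture:

    import tensorflow as tf
    from transformers import AlbertConfig
    from transformers.modeling_tf_albert import TFAlbertMainLayer

    config = AlbertConfig()  # illustrative default config
    input_ids = tf.keras.Input(shape=(64,), dtype=tf.int32, name="input_ids")
    model = tf.keras.Model(inputs=input_ids, outputs=TFAlbertMainLayer(config)(input_ids))

    json_config = model.to_json()  # architecture only, no weights
    restored = tf.keras.models.model_from_json(
        json_config, custom_objects={"TFAlbertMainLayer": TFAlbertMainLayer}
    )
    restored.set_weights(model.get_weights())  # weights are transferred separately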
src/transformers/modeling_tf_bert.py  (+4, -1)

@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list


 logger = logging.getLogger(__name__)
@@ -471,7 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
         return seq_relationship_score


+@keras_serializable
 class TFBertMainLayer(tf.keras.layers.Layer):
+    config_class = BertConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
src/transformers/modeling_tf_ctrl.py  (+4, -1)

@@ -23,7 +23,7 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list


 logger = logging.getLogger(__name__)
@@ -164,7 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
         return outputs


+@keras_serializable
 class TFCTRLMainLayer(tf.keras.layers.Layer):
+    config_class = CTRLConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
src/transformers/modeling_tf_gpt2.py  (+4, -0)

@@ -29,6 +29,7 @@ from .modeling_tf_utils import (
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
+    keras_serializable,
     shape_list,
 )
@@ -196,7 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
         return outputs  # x, present, (attentions)


+@keras_serializable
 class TFGPT2MainLayer(tf.keras.layers.Layer):
+    config_class = GPT2Config
+
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
src/transformers/modeling_tf_transfo_xl.py  (+4, -1)

@@ -24,7 +24,7 @@ import tensorflow as tf
 from .configuration_transfo_xl import TransfoXLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list


 logger = logging.getLogger(__name__)
@@ -378,7 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         return embed


+@keras_serializable
 class TFTransfoXLMainLayer(tf.keras.layers.Layer):
+    config_class = TransfoXLConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
src/transformers/modeling_tf_utils.py  (+59, -2)

@@ -14,8 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """TF general model utils."""
+import functools
 import logging
 import os
@@ -47,6 +46,64 @@ class TFModelUtilsMixin:
         return self.count_params()


+def keras_serializable(cls):
+    """
+    Decorate a Keras Layer class to support Keras serialization.
+
+    This is done by:
+    1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
+       serialization time)
+    2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
+       convert it to a config object for the actual layer initializer
+    3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
+       not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`
+
+    :param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a
+       `TF*MainLayer` class in this project)
+    :return: the same class object, with modifications for Keras deserialization.
+    """
+    initializer = cls.__init__
+    config_class = getattr(cls, "config_class", None)
+    if config_class is None:
+        raise AttributeError("Must set `config_class` to use @keras_serializable")
+
+    @functools.wraps(initializer)
+    def wrapped_init(self, *args, **kwargs):
+        transformers_config = kwargs.pop("transformers_config", None)
+        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
+        if config is not None and transformers_config is not None:
+            raise ValueError("Must pass either `config` or `transformers_config`, not both")
+        elif config is not None:
+            # normal layer construction, call with unchanged args (config is already in there)
+            initializer(self, *args, **kwargs)
+        elif transformers_config is not None:
+            # Keras deserialization, convert dict to config
+            config = config_class.from_dict(transformers_config)
+            initializer(self, config, *args, **kwargs)
+        else:
+            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
+        self._transformers_config = config
+
+    cls.__init__ = wrapped_init
+
+    if not hasattr(cls, "get_config"):
+        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
+    if hasattr(cls.get_config, "_is_default"):
+
+        def get_config(self):
+            cfg = super(cls, self).get_config()
+            cfg["transformers_config"] = self._transformers_config.to_dict()
+            return cfg
+
+        cls.get_config = get_config
+
+    cls._keras_serializable = True
+    if hasattr(tf.keras.utils, "register_keras_serializable"):
+        cls = tf.keras.utils.register_keras_serializable()(cls)
+    return cls
+
+
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
     r""" Base class for all TF models.
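The decorator above is the core of the change: wrapped_init accepts either a config object (normal construction) or a transformers_config dict (Keras deserialization), and the injected get_config writes that dict back out so Keras can rebuild the layer. A rough sketch of the round trip (illustrative, assuming the decorated TFBertMainLayer from this commit; weight handling omitted):

    from transformers import BertConfig
    from transformers.modeling_tf_bert import TFBertMainLayer

    layer = TFBertMainLayer(BertConfig())  # normal construction: wrapped_init sees a config object
    cfg = layer.get_config()               # injected get_config adds cfg["transformers_config"]

    # Keras' default from_config calls cls(**cfg), so wrapped_init receives the
    # "transformers_config" dict and rebuilds the config via BertConfig.from_dict(...)
    restored = TFBertMainLayer.from_config(cfg)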
src/transformers/modeling_tf_xlnet.py  (+11, -1)

@@ -24,7 +24,14 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
+from .modeling_tf_utils import (
+    TFPreTrainedModel,
+    TFSequenceSummary,
+    TFSharedEmbeddings,
+    get_initializer,
+    keras_serializable,
+    shape_list,
+)


 logger = logging.getLogger(__name__)
@@ -342,7 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
         return hidden_states


+@keras_serializable
 class TFXLNetMainLayer(tf.keras.layers.Layer):
+    config_class = XLNetConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
tests/test_modeling_tf_common.py  (+44, -7)

@@ -19,6 +19,7 @@ import os
 import random
 import tempfile
 import unittest
+from importlib import import_module

 from transformers import is_tf_available, is_torch_available
@@ -89,9 +90,45 @@ class TFModelTesterMixin:
             model = model_class.from_pretrained(tmpdirname)
             after_outputs = model(inputs_dict)

+            self.assert_outputs_same(after_outputs, outputs)
+
+    def test_keras_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        tf_main_layer_classes = set(
+            module_member
+            for model_class in self.all_model_classes
+            for module in (import_module(model_class.__module__),)
+            for module_member_name in dir(module)
+            if module_member_name.endswith("MainLayer")
+            for module_member in (getattr(module, module_member_name),)
+            if isinstance(module_member, type)
+            and tf.keras.layers.Layer in module_member.__bases__
+            and getattr(module_member, "_keras_serializable", False)
+        )
+        for main_layer_class in tf_main_layer_classes:
+            main_layer = main_layer_class(config)
+            symbolic_inputs = {
+                name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
+            }
+            model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
+            outputs = model(inputs_dict)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                filepath = os.path.join(tmpdirname, "keras_model.h5")
+                model.save(filepath)
+                model = tf.keras.models.load_model(
+                    filepath, custom_objects={main_layer_class.__name__: main_layer_class}
+                )
+                assert isinstance(model, tf.keras.Model)
+                after_outputs = model(inputs_dict)
+                self.assert_outputs_same(after_outputs, outputs)
+
+    def assert_outputs_same(self, after_outputs, outputs):
         # Make sure we don't have nans
         out_1 = after_outputs[0].numpy()
         out_2 = outputs[0].numpy()
         self.assertEqual(out_1.shape, out_2.shape)
         out_1 = out_1[~np.isnan(out_1)]
         out_2 = out_2[~np.isnan(out_2)]
         max_diff = np.amax(np.abs(out_1 - out_2))