Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ba281707
"vscode:/vscode.git/clone" did not exist on "20fc18fbda3669c2f4a3510e0705b2acd54bff07"
Commit
ba281707
authored
Mar 03, 2020
by
Gunnlaugur Thor Briem
Browse files
Support keras JSON/HDF5 serialization of main layers
Fixes #3101
parent
a088d75e
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
59 additions
and
26 deletions
+59
-26
src/transformers/modeling_tf_albert.py
src/transformers/modeling_tf_albert.py
+2
-2
src/transformers/modeling_tf_bert.py
src/transformers/modeling_tf_bert.py
+3
-3
src/transformers/modeling_tf_ctrl.py
src/transformers/modeling_tf_ctrl.py
+3
-3
src/transformers/modeling_tf_distilbert.py
src/transformers/modeling_tf_distilbert.py
+3
-3
src/transformers/modeling_tf_gpt2.py
src/transformers/modeling_tf_gpt2.py
+3
-2
src/transformers/modeling_tf_openai.py
src/transformers/modeling_tf_openai.py
+2
-1
src/transformers/modeling_tf_t5.py
src/transformers/modeling_tf_t5.py
+3
-3
src/transformers/modeling_tf_transfo_xl.py
src/transformers/modeling_tf_transfo_xl.py
+3
-3
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+17
-0
src/transformers/modeling_tf_xlm.py
src/transformers/modeling_tf_xlm.py
+10
-3
src/transformers/modeling_tf_xlnet.py
src/transformers/modeling_tf_xlnet.py
+10
-3
No files found.
src/transformers/modeling_tf_albert.py
View file @
ba281707
...
...
@@ -23,7 +23,7 @@ import tensorflow as tf
from
.configuration_albert
import
AlbertConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_bert
import
ACT2FN
,
TFBertSelfAttention
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
from
.modeling_tf_utils
import
TFMainLayer
,
TFPreTrainedModel
,
get_initializer
,
shape_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -478,7 +478,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
return
hidden_states
class
TFAlbertMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFAlbertMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
config
,
**
kwargs
)
self
.
num_hidden_layers
=
config
.
num_hidden_layers
...
...
src/transformers/modeling_tf_bert.py
View file @
ba281707
...
...
@@ -23,7 +23,7 @@ import tensorflow as tf
from
.configuration_bert
import
BertConfig
from
.file_utils
import
MULTIPLE_CHOICE_DUMMY_INPUTS
,
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
from
.modeling_tf_utils
import
TFMainLayer
,
TFPreTrainedModel
,
get_initializer
,
shape_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer):
return
seq_relationship_score
class
TFBertMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFBertMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
num_hidden_layers
=
config
.
num_hidden_layers
self
.
embeddings
=
TFBertEmbeddings
(
config
,
name
=
"embeddings"
)
...
...
src/transformers/modeling_tf_ctrl.py
View file @
ba281707
...
...
@@ -23,7 +23,7 @@ import tensorflow as tf
from
.configuration_ctrl
import
CTRLConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
from
.modeling_tf_utils
import
TFMainLayer
,
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return
outputs
class
TFCTRLMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFCTRLMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_attentions
=
config
.
output_attentions
self
.
output_past
=
config
.
output_past
...
...
src/transformers/modeling_tf_distilbert.py
View file @
ba281707
...
...
@@ -24,7 +24,7 @@ import tensorflow as tf
from
.configuration_distilbert
import
DistilBertConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
get_initializer
,
shape_list
from
.modeling_tf_utils
import
TFMainLayer
,
TFPreTrainedModel
,
TFSharedEmbeddings
,
get_initializer
,
shape_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -397,9 +397,9 @@ class TFTransformer(tf.keras.layers.Layer):
return
outputs
# last-layer hidden state, (all hidden states), (all attentions)
class
TFDistilBertMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFDistilBertMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
num_hidden_layers
=
config
.
num_hidden_layers
self
.
embeddings
=
TFEmbeddings
(
config
,
name
=
"embeddings"
)
# Embeddings
...
...
src/transformers/modeling_tf_gpt2.py
View file @
ba281707
...
...
@@ -25,6 +25,7 @@ from .configuration_gpt2 import GPT2Config
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
(
TFConv1D
,
TFMainLayer
,
TFPreTrainedModel
,
TFSequenceSummary
,
TFSharedEmbeddings
,
...
...
@@ -196,9 +197,9 @@ class TFBlock(tf.keras.layers.Layer):
return
outputs
# x, present, (attentions)
class
TFGPT2MainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFGPT2MainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
().
__init__
(
*
inputs
,
**
kwargs
)
super
().
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_attentions
=
config
.
output_attentions
self
.
num_hidden_layers
=
config
.
n_layer
...
...
src/transformers/modeling_tf_openai.py
View file @
ba281707
...
...
@@ -25,6 +25,7 @@ from .configuration_openai import OpenAIGPTConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
(
TFConv1D
,
TFMainLayer
,
TFPreTrainedModel
,
TFSequenceSummary
,
TFSharedEmbeddings
,
...
...
@@ -197,7 +198,7 @@ class TFBlock(tf.keras.layers.Layer):
return
outputs
# x, (attentions)
class
TFOpenAIGPTMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFOpenAIGPTMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
().
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
output_hidden_states
=
config
.
output_hidden_states
...
...
src/transformers/modeling_tf_t5.py
View file @
ba281707
...
...
@@ -25,7 +25,7 @@ import tensorflow as tf
from
.configuration_t5
import
T5Config
from
.file_utils
import
DUMMY_INPUTS
,
DUMMY_MASK
,
add_start_docstrings
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
from
.modeling_tf_utils
import
TFMainLayer
,
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -359,9 +359,9 @@ class TFT5Block(tf.keras.layers.Layer):
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
class
TFT5MainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFT5MainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
is_decoder
=
config
.
is_decoder
...
...
src/transformers/modeling_tf_transfo_xl.py
View file @
ba281707
...
...
@@ -24,7 +24,7 @@ import tensorflow as tf
from
.configuration_transfo_xl
import
TransfoXLConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_transfo_xl_utilities
import
TFAdaptiveSoftmaxMask
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
from
.modeling_tf_utils
import
TFMainLayer
,
TFPreTrainedModel
,
get_initializer
,
shape_list
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -378,9 +378,9 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
return
embed
class
TFTransfoXLMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFTransfoXLMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
...
...
src/transformers/modeling_tf_utils.py
View file @
ba281707
...
...
@@ -47,6 +47,23 @@ class TFModelUtilsMixin:
return
self
.
count_params
()
class TFMainLayer(tf.keras.layers.Layer):
    """
    Shared base class for the main layers of the models.

    It keeps the transformers configuration on the layer and exposes it through `get_config`,
    which is what Keras JSON serialization relies on.
    """

    def __init__(self, config, **kwargs):
        """
        Store the model configuration on the layer.

        Args:
            config: either a configuration object or a plain ``dict``; a ``dict`` is turned
                into a configuration object with ``PretrainedConfig.from_dict``.
            **kwargs: forwarded unchanged to ``tf.keras.layers.Layer.__init__``.
        """
        super().__init__(**kwargs)
        # Accept the dict form too, so the value emitted by `get_config` can be passed back in.
        self._transformers_config = PretrainedConfig.from_dict(config) if isinstance(config, dict) else config

    def get_config(self):
        """Return the base Keras layer config extended with a ``"config"`` entry (the configuration as a dict)."""
        layer_config = super().get_config()
        layer_config["config"] = self._transformers_config.to_dict()
        return layer_config
class
TFPreTrainedModel
(
tf
.
keras
.
Model
,
TFModelUtilsMixin
):
r
""" Base class for all TF models.
...
...
src/transformers/modeling_tf_xlm.py
View file @
ba281707
...
...
@@ -25,7 +25,14 @@ import tensorflow as tf
from
.configuration_xlm
import
XLMConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSequenceSummary
,
TFSharedEmbeddings
,
get_initializer
,
shape_list
from
.modeling_tf_utils
import
(
TFMainLayer
,
TFPreTrainedModel
,
TFSequenceSummary
,
TFSharedEmbeddings
,
get_initializer
,
shape_list
,
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -196,9 +203,9 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return
x
class
TFXLMMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFXLMMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
...
...
src/transformers/modeling_tf_xlnet.py
View file @
ba281707
...
...
@@ -24,7 +24,14 @@ import tensorflow as tf
from
.configuration_xlnet
import
XLNetConfig
from
.file_utils
import
add_start_docstrings
,
add_start_docstrings_to_callable
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSequenceSummary
,
TFSharedEmbeddings
,
get_initializer
,
shape_list
from
.modeling_tf_utils
import
(
TFMainLayer
,
TFPreTrainedModel
,
TFSequenceSummary
,
TFSharedEmbeddings
,
get_initializer
,
shape_list
,
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -342,9 +349,9 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
return
hidden_states
class
TFXLNetMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFXLNetMainLayer
(
TFMain
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
config
,
**
kwargs
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_past
=
config
.
output_past
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment