Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
470753bc
Commit
470753bc
authored
Mar 03, 2020
by
Gunnlaugur Thor Briem
Browse files
Put @keras_serializable only on layers it works on
And only run the test on TF*MainLayer classes so marked.
parent
0c716ede
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
2 additions
and
7 deletions
+2
-7
src/transformers/modeling_tf_distilbert.py
src/transformers/modeling_tf_distilbert.py
+0
-1
src/transformers/modeling_tf_openai.py
src/transformers/modeling_tf_openai.py
+0
-1
src/transformers/modeling_tf_t5.py
src/transformers/modeling_tf_t5.py
+0
-1
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+1
-0
src/transformers/modeling_tf_xlm.py
src/transformers/modeling_tf_xlm.py
+0
-1
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+1
-3
No files found.
src/transformers/modeling_tf_distilbert.py
View file @
470753bc
...
@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
...
@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
return
outputs
# last-layer hidden state, (all hidden states), (all attentions)
return
outputs
# last-layer hidden state, (all hidden states), (all attentions)
@
keras_serializable
class
TFDistilBertMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFDistilBertMainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
...
...
src/transformers/modeling_tf_openai.py
View file @
470753bc
...
@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
...
@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
return
outputs
# x, (attentions)
return
outputs
# x, (attentions)
@
keras_serializable
class
TFOpenAIGPTMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFOpenAIGPTMainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
().
__init__
(
*
inputs
,
**
kwargs
)
super
().
__init__
(
*
inputs
,
**
kwargs
)
...
...
src/transformers/modeling_tf_t5.py
View file @
470753bc
...
@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
...
@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
# The full model without a specific pretrained or finetuning head is
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
####################################################
@
keras_serializable
class
TFT5MainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFT5MainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
...
...
src/transformers/modeling_tf_utils.py
View file @
470753bc
...
@@ -71,6 +71,7 @@ def keras_serializable(cls):
...
@@ -71,6 +71,7 @@ def keras_serializable(cls):
cls
.
get_config
=
get_config
cls
.
get_config
=
get_config
cls
.
_keras_serializable
=
True
return
tf
.
keras
.
utils
.
register_keras_serializable
()(
cls
)
return
tf
.
keras
.
utils
.
register_keras_serializable
()(
cls
)
...
...
src/transformers/modeling_tf_xlm.py
View file @
470753bc
...
@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
...
@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return
x
return
x
@
keras_serializable
class
TFXLMMainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFXLMMainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
...
...
tests/test_modeling_tf_common.py
View file @
470753bc
...
@@ -103,11 +103,9 @@ class TFModelTesterMixin:
...
@@ -103,11 +103,9 @@ class TFModelTesterMixin:
if
module_member_name
.
endswith
(
"MainLayer"
)
if
module_member_name
.
endswith
(
"MainLayer"
)
for
module_member
in
(
getattr
(
module
,
module_member_name
),)
for
module_member
in
(
getattr
(
module
,
module_member_name
),)
if
isinstance
(
module_member
,
type
)
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
if
isinstance
(
module_member
,
type
)
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
and
getattr
(
module_member
,
'_keras_serializable'
,
False
)
)
)
for
main_layer_class
in
tf_main_layer_classes
:
for
main_layer_class
in
tf_main_layer_classes
:
if
main_layer_class
.
__name__
==
"TFT5MainLayer"
:
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
continue
main_layer
=
main_layer_class
(
config
)
main_layer
=
main_layer_class
(
config
)
symbolic_inputs
=
{
symbolic_inputs
=
{
name
:
tf
.
keras
.
Input
(
tensor
.
shape
[
1
:],
dtype
=
tensor
.
dtype
)
for
name
,
tensor
in
inputs_dict
.
items
()
name
:
tf
.
keras
.
Input
(
tensor
.
shape
[
1
:],
dtype
=
tensor
.
dtype
)
for
name
,
tensor
in
inputs_dict
.
items
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment