chenpangpang / transformers / Commits

Commit c9e8c519, authored Oct 10, 2019 by thomwolf
parent da26bae6

fixing SequenceSummary head in TF 2.0
Showing 1 changed file with 16 additions and 16 deletions.

transformers/modeling_tf_utils.py (+16 / -16)
@@ -394,8 +394,8 @@ class TFSequenceSummary(tf.keras.layers.Layer):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = None
-        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
+        self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
+        if self.has_summary:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
             else:
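The hunk above replaces a `None` sentinel (`self.summary = None`) with a boolean flag (`self.has_summary`) computed once from the optional config attribute. A minimal sketch of the same pattern, using `types.SimpleNamespace` as a stand-in for the real transformers config object:

# Sketch of the hasattr-guard pattern introduced above. SimpleNamespace is a
# stand-in for the real transformers config; 'summary_use_proj' is optional.
from types import SimpleNamespace

config = SimpleNamespace(summary_use_proj=True)

# Old style: a None sentinel, later tested with `is not None` in call().
summary = None
if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
    summary = 'dense-projection-placeholder'

# New style: fold the optional-attribute test into one flag computed once,
# so call() can branch on the flag instead of on the attribute being None.
has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
assert has_summary is True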
@@ -404,16 +404,16 @@ class TFSequenceSummary(tf.keras.layers.Layer):
                                               kernel_initializer=get_initializer(initializer_range),
                                               name='summary')
 
-        self.activation = None
-        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+        self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
+        if self.has_activation:
             self.activation = tf.keras.activations.tanh
 
-        self.first_dropout = None
-        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
+        self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
+        if self.has_first_dropout:
             self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = None
-        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
+        self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
+        if self.has_last_dropout:
             self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
 
     def call(self, inputs, training=False):
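After this hunk, the constructor builds each optional sub-layer only when its flag is set; the `None` defaults are gone entirely. A standalone sketch of the same construction logic, with a hypothetical toy config whose attribute names mirror the real ones:

# Condensed sketch of the constructor pattern after this hunk; the toy config
# below is hypothetical, but the attribute names mirror the ones used above.
import tensorflow as tf
from types import SimpleNamespace

config = SimpleNamespace(summary_activation='tanh',
                         summary_first_dropout=0.1,
                         summary_last_dropout=0.0)

has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
if has_activation:
    activation = tf.keras.activations.tanh

has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
if has_first_dropout:
    first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

# A rate of 0.0 means the layer is never built at all, so there is no longer
# a None-valued attribute to test for in call().
has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
assert not has_last_dropout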
@@ -456,17 +456,17 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         elif self.summary_type == 'attn':
             raise NotImplementedError
 
-        if training and self.first_dropout is not None:
-            output = self.first_dropout(output)
+        if self.has_first_dropout:
+            output = self.first_dropout(output, training=training)
 
-        if self.summary is not None:
+        if self.has_summary:
             output = self.summary(output)
 
-        if self.activation is not None:
+        if self.has_activation:
             output = self.activation(output)
 
-        if training and self.last_dropout is not None:
-            output = self.last_dropout(output)
+        if self.has_last_dropout:
+            output = self.last_dropout(output, training=training)
 
         return output
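The substantive fix is in these call sites: instead of guarding dropout behind a Python-level `if training and ... is not None:`, the layer is always invoked and the `training` flag is forwarded to it, which is the idiomatic TF 2.0 / Keras way to switch between train and inference behavior. A minimal runnable sketch of that behavior (not part of the commit):

# tf.keras.layers.Dropout is an identity map unless called with training=True,
# so forwarding training= lets the layer own the train/eval switch.
import tensorflow as tf

drop = tf.keras.layers.Dropout(rate=0.5)
x = tf.ones((1, 8))

print(drop(x, training=False).numpy())  # unchanged: all ones
print(drop(x, training=True).numpy())   # ~half zeroed, survivors scaled by 1/(1-0.5)=2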