chenpangpang / transformers · Commits

Unverified commit 0397619a, authored Oct 22, 2020 by Sam Shleifer, committed by GitHub on Oct 22, 2020.
Parent: 5ac07513

Move NoLayerEmbedTokens (#7945)

* Move NoLayerEmbedTokens
* TFWrappedEmbeddings
* Add comment
Showing 3 changed files, with 43 additions and 67 deletions (+43 / -67):

  src/transformers/modeling_tf_bart.py    +3   -31
  src/transformers/modeling_tf_t5.py      +10  -36
  src/transformers/modeling_tf_utils.py   +30  -0
src/transformers/modeling_tf_bart.py
@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
     DUMMY_INPUTS,
     TFPreTrainedModel,
     TFSharedEmbeddings,
+    TFWrappedEmbeddings,
     cast_bool_to_primitive,
     keras_serializable,
     shape_list,

@@ -132,36 +133,6 @@ LARGE_NEGATIVE = -1e8
 logger = logging.get_logger(__name__)
 
 
-class _NoLayerEmbedTokens:
-    """
-    this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
-    class to avoid problem with weight restoring. Also it makes sure that the layer is
-    called from the correct scope to avoid problem with saving/storing the correct weights
-    """
-
-    def __init__(self, layer, abs_scope_name=None):
-        self._layer = layer
-        self._abs_scope_name = abs_scope_name
-
-    def call(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer.call(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer.call(inputs, mode)
-
-    def __call__(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer(inputs, mode)
-
-
 def create_position_ids_from_input_ids(input_ids, padding_idx):
     """Replace non-padding symbols with their position numbers. Position numbers begin at
     padding_idx+1. Padding symbols are ignored. This is modified from fairseq's

@@ -826,7 +797,8 @@ class TFBartModel(TFPretrainedBartModel):
         with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
             pass
 
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
         embed_tokens.vocab_size = self.shared.vocab_size
         embed_tokens.hidden_size = self.shared.hidden_size
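For orientation, here is a minimal sketch (not part of the commit) of the wrapping pattern TFBartModel uses above, assuming the library as of this commit, where both TFSharedEmbeddings and the new TFWrappedEmbeddings are importable from transformers.modeling_tf_utils; the vocabulary and hidden sizes are arbitrary illustration values.

# Minimal sketch of the pattern used in TFBartModel above
# (assumes transformers at this commit and TensorFlow 2.x; sizes are arbitrary).
import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings, TFWrappedEmbeddings

# The shared embedding layer, created under the name "model.shared".
shared = TFSharedEmbeddings(vocab_size=50265, hidden_size=16, name="model.shared")

# Capture the absolute variable scope the shared weights belong to.
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
    pass

# Wrap the layer in a plain-Python object; calls are redirected into the captured scope.
embed_tokens = TFWrappedEmbeddings(shared, abs_scope_name=shared_abs_scope_name)

input_ids = tf.constant([[0, 31414, 232, 2]])
hidden = embed_tokens(input_ids, mode="embedding")  # (1, 4, 16) token embeddings
logits = embed_tokens(hidden, mode="linear")        # (1, 4, 50265) projection with the same matrix

TFSharedEmbeddings supports both an embedding lookup and a "linear" mode that reuses the same matrix as an output projection, so a single wrapped object can be handed around wherever the weights are tied.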
src/transformers/modeling_tf_t5.py
@@ -24,6 +24,8 @@ from typing import Tuple
 
 import tensorflow as tf
 
+from transformers.modeling_tf_utils import TFWrappedEmbeddings
+
 from .configuration_t5 import T5Config
 from .file_utils import (
     DUMMY_INPUTS,

@@ -505,36 +507,6 @@ class TFT5Block(tf.keras.layers.Layer):
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
 
 
-class _NoLayerEmbedTokens:
-    """
-    this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
-    class to avoid problem with weight restoring. Also it makes sure that the layer is
-    called from the correct scope to avoid problem with saving/storing the correct weights
-    """
-
-    def __init__(self, layer, abs_scope_name=None):
-        self._layer = layer
-        self._abs_scope_name = abs_scope_name
-
-    def call(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer.call(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer.call(inputs, mode)
-
-    def __call__(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer(inputs, mode)
-
-
 ####################################################
 # The full model without a specific pretrained or finetuning head is
 # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"

@@ -980,8 +952,8 @@ class TFT5Model(TFT5PreTrainedModel):
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
 
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.use_cache = False

@@ -1003,7 +975,8 @@ class TFT5Model(TFT5PreTrainedModel):
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
 
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

@@ -1177,8 +1150,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
 
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.use_cache = False

@@ -1199,7 +1172,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
 
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)
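A note on why the wrapper is a plain Python class rather than a tf.keras.layers.Layer (my reading of the class docstring, illustrated with hypothetical toy classes rather than transformers APIs): Keras auto-tracks any Layer assigned as an attribute, so handing the same embedding Layer to both encoder and decoder would register its weights under two parents and complicate checkpoint naming and restoring. A non-Layer wrapper is invisible to that tracking.

import tensorflow as tf

class PlainWrapper:  # hypothetical stand-in for TFWrappedEmbeddings (no scope handling)
    def __init__(self, layer):
        self._layer = layer

    def __call__(self, inputs):
        return self._layer(inputs)

class Stack(tf.keras.layers.Layer):  # hypothetical stand-in for an encoder or decoder stack
    def __init__(self, embed_tokens, **kwargs):
        super().__init__(**kwargs)
        self.embed_tokens = embed_tokens  # a plain object is not auto-tracked by Keras

    def call(self, input_ids):
        return self.embed_tokens(input_ids)

shared = tf.keras.layers.Embedding(100, 8, name="shared")
encoder = Stack(PlainWrapper(shared), name="encoder")
decoder = Stack(PlainWrapper(shared), name="decoder")

encoder(tf.constant([[1, 2, 3]]))
decoder(tf.constant([[4, 5]]))

# The embedding weights keep a single owner; neither sub-layer claims them.
print(len(shared.weights), len(encoder.weights), len(decoder.weights))  # 1 0 0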
src/transformers/modeling_tf_utils.py
@@ -1065,3 +1065,33 @@ def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor
     # else variable is bool
     return bool_variable
+
+
+class TFWrappedEmbeddings:
+    """
+    this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
+    class to avoid problem with weight restoring. Also it makes sure that the layer is
+    called from the correct scope to avoid problem with saving/storing the correct weights
+    """
+
+    def __init__(self, layer, abs_scope_name=None):
+        self._layer = layer
+        self._abs_scope_name = abs_scope_name
+
+    def call(self, inputs, mode="embedding"):
+        if self._abs_scope_name is None:
+            return self._layer.call(inputs, mode)
+
+        # if an abs scope name is given to the embedding variable, call variable from absolute scope
+        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
+            with tf.name_scope(abs_scope_name.original_name_scope):
+                return self._layer.call(inputs, mode)
+
+    def __call__(self, inputs, mode="embedding"):
+        if self._abs_scope_name is None:
+            return self._layer(inputs, mode)
+
+        # if an abs scope name is given to the embedding variable, call variable from absolute scope
+        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
+            with tf.name_scope(abs_scope_name.original_name_scope):
+                return self._layer(inputs, mode)
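With the class now defined once here, both model files import a single implementation instead of carrying private _NoLayerEmbedTokens copies. A short hedged usage sketch (sizes arbitrary, assuming the library at this commit), exercising the abs_scope_name=None branch in which calls are simply forwarded to the wrapped layer:

import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings, TFWrappedEmbeddings

shared = TFSharedEmbeddings(vocab_size=32128, hidden_size=64, name="shared")
wrapped = TFWrappedEmbeddings(shared)  # abs_scope_name=None: __call__ forwards to shared directly

ids = tf.constant([[37, 423, 5]])
print(wrapped(ids, mode="embedding").shape)  # (1, 3, 64)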