chenpangpang / transformers
Commit 600a4232, authored Sep 05, 2019 by thomwolf (parent 04d2006f)

    add weights tying, attention and hidden states output tests
Showing 3 changed files with 125 additions and 83 deletions:
pytorch_transformers/modeling_tf_bert.py                +58 −23
pytorch_transformers/modeling_tf_utils.py               +19 −6
pytorch_transformers/tests/modeling_tf_common_test.py   +48 −54
pytorch_transformers/modeling_tf_bert.py @ 600a4232
@@ -141,7 +141,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
     """
     def __init__(self, config, **kwargs):
         super(TFBertEmbeddings, self).__init__(**kwargs)
-        self.word_embeddings = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size, name='word_embeddings')
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size
         self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings')
         self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings')
@@ -150,8 +152,44 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
+    def build(self, input_shape):
+        """Build shared word embedding layer """
+        with tf.name_scope("word_embeddings"):
+            # Create and initialize weights. The random normal initializer was chosen
+            # arbitrarily, and works well.
+            self.word_embeddings = self.add_weight(
+                "weight",
+                shape=[self.vocab_size, self.hidden_size],
+                initializer=tf.random_normal_initializer(mean=0., stddev=self.hidden_size**-0.5))
+        super(TFBertEmbeddings, self).build(input_shape)
+
-    @tf.function
-    def call(self, inputs, training=False):
+    def call(self, inputs, mode="embedding", training=False):
+        """Get token embeddings of inputs.
+        Args:
+            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
+            mode: string, a valid value is one of "embedding" and "linear".
+        Returns:
+            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+                linear tensor, float32 with shape [batch_size, length, vocab_size].
+        Raises:
+            ValueError: if mode is not valid.
+        Shared weights logic adapted from
+            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        if mode == "embedding":
+            return self._embedding(inputs, training=training)
+        elif mode == "linear":
+            return self._linear(inputs)
+        else:
+            raise ValueError("mode {} is not valid.".format(mode))
+
+    def _embedding(self, inputs, training=False):
+        """Applies embedding based on inputs tensor."""
         # Create binary mask of size [batch_size, length]
         input_ids, position_ids, token_type_ids = inputs
         seq_length = tf.shape(input_ids)[1]
@@ -160,7 +198,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         if token_type_ids is None:
             token_type_ids = tf.fill(tf.shape(input_ids), 0)
 
-        words_embeddings = self.word_embeddings(input_ids)
+        words_embeddings = tf.gather(self.word_embeddings, input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -170,6 +208,21 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         embeddings = self.dropout(embeddings)
         return embeddings
 
+    def _linear(self, inputs):
+        """Computes logits by running inputs through a linear layer.
+            Args:
+                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+            Returns:
+                float32 tensor with shape [batch_size, length, vocab_size].
+        """
+        batch_size = tf.shape(inputs)[0]
+        length = tf.shape(inputs)[1]
+
+        x = tf.reshape(inputs, [-1, self.hidden_size])
+        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
+
+        return tf.reshape(logits, [batch_size, length, self.vocab_size])
 
 
 class TFBertSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
@@ -448,8 +501,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')
 
-        # self.apply(self.init_weights)  # TODO check weights initialization
-
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -692,22 +743,14 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         super(TFBertForPreTraining, self).__init__(config)
 
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
         self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
 
-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-        """
-        pass  # TODO add weights tying
-
     @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.cls_mlm(sequence_output)
+        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
         seq_relationship_score = self.cls_nsp(pooled_output)
 
         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
@@ -751,21 +794,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         super(TFBertForMaskedLM, self).__init__(config)
 
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
-
-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-        """
-        pass  # TODO add weights tying
-
     @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
         sequence_output = outputs[0]
-        prediction_scores = self.cls_mlm(sequence_output)
+        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
 
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
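Note on the change above: the diff ties the input and output embeddings by reusing one weight matrix, once as a lookup table ("embedding" mode) and once, transposed, as the output projection ("linear" mode). The snippet below is a minimal, self-contained sketch of that pattern; it is not part of the commit, and the class and variable names are illustrative only.

import tensorflow as tf

class SharedEmbedding(tf.keras.layers.Layer):
    """Illustrative tied input/output embedding, mirroring TFBertEmbeddings above."""

    def __init__(self, vocab_size, hidden_size, **kwargs):
        super(SharedEmbedding, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # One weight matrix shared by both modes.
        self.weight = self.add_weight(
            "weight",
            shape=[self.vocab_size, self.hidden_size],
            initializer=tf.random_normal_initializer(mean=0., stddev=self.hidden_size ** -0.5))
        super(SharedEmbedding, self).build(input_shape)

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            # [batch, length] int ids -> [batch, length, hidden] embeddings
            return tf.gather(self.weight, inputs)
        elif mode == "linear":
            # [batch, length, hidden] hidden states -> [batch, length, vocab] logits
            batch_size = tf.shape(inputs)[0]
            length = tf.shape(inputs)[1]
            x = tf.reshape(inputs, [-1, self.hidden_size])
            logits = tf.matmul(x, self.weight, transpose_b=True)
            return tf.reshape(logits, [batch_size, length, self.vocab_size])
        raise ValueError("mode {} is not valid.".format(mode))

# Usage: the same matrix maps ids to vectors and hidden states back to vocabulary logits.
layer = SharedEmbedding(vocab_size=100, hidden_size=16)
ids = tf.constant([[1, 2, 3]])
hidden = layer(ids, mode="embedding")   # shape (1, 3, 16)
logits = layer(hidden, mode="linear")   # shape (1, 3, 100)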
pytorch_transformers/modeling_tf_utils.py @ 600a4232
@@ -64,7 +64,7 @@ class TFPreTrainedModel(tf.keras.Model):
         self.config = config
 
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
-        """ Build a resized Embedding Module from a provided token Embedding Module.
+        """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end
@@ -77,12 +77,25 @@ class TFPreTrainedModel(tf.keras.Model):
         Return: ``torch.nn.Embeddings``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
         raise NotImplementedError
         # if new_num_tokens is None:
         #     return old_embeddings
-
-    def _tie_or_clone_weights(self, first_module, second_module):
-        """ Tie or clone module weights depending of weither we are using TorchScript or not
-        """
-        raise NotImplementedError
+        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        # if old_num_tokens == new_num_tokens:
+        #     return old_embeddings
+        # # Build new embeddings
+        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        # new_embeddings.to(old_embeddings.weight.device)
+        # # initialize all new embeddings (in particular added tokens)
+        # self._init_weights(new_embeddings)
+        # # Copy word embeddings from the previous weights
+        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
+        # return new_embeddings
 
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
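The commented-out PyTorch logic above spells out the intended resize behavior: copy the first min(old, new) rows and leave any newly added rows freshly initialized. As a rough illustration only (this helper is not part of the commit, and the TF method here still raises NotImplementedError), a TensorFlow version of that idea could look like the following sketch.

import tensorflow as tf

def resize_embedding_matrix(old_weights, new_num_tokens, stddev=0.02):
    """Return a [new_num_tokens, hidden_size] matrix reusing rows of old_weights.

    old_weights: tensor/Variable of shape [old_num_tokens, hidden_size].
    Rows beyond the copied range keep their fresh random initialization.
    """
    old_num_tokens, hidden_size = old_weights.shape
    if new_num_tokens is None or new_num_tokens == old_num_tokens:
        return old_weights
    new_weights = tf.random.normal([new_num_tokens, int(hidden_size)], stddev=stddev)
    num_to_copy = min(int(old_num_tokens), new_num_tokens)
    # Copy the overlapping vocabulary rows from the previous weights.
    resized = tf.concat([old_weights[:num_to_copy], new_weights[num_to_copy:]], axis=0)
    return tf.Variable(resized)

# Example: grow a 5-token embedding table to 8 tokens.
old = tf.Variable(tf.random.normal([5, 4]))
new = resize_embedding_matrix(old, 8)   # shape (8, 4); first 5 rows copied from `old`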
pytorch_transformers/tests/modeling_tf_common_test.py @ 600a4232
@@ -64,44 +64,40 @@ class TFCommonTestCases:
 
     def test_attention_outputs(self):
-        pass
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        # for model_class in self.all_model_classes:
-        #     config.output_attentions = True
-        #     config.output_hidden_states = False
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     attentions = outputs[-1]
-        #     self.assertEqual(model.config.output_attentions, True)
-        #     self.assertEqual(model.config.output_hidden_states, False)
-        #     self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-        #     self.assertListEqual(
-        #         list(attentions[0].shape[-3:]),
-        #         [self.model_tester.num_attention_heads,
-        #          self.model_tester.seq_length,
-        #          self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
-        #     out_len = len(outputs)
-
-        #     # Check attention is always last and order is fine
-        #     config.output_attentions = True
-        #     config.output_hidden_states = True
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     self.assertEqual(out_len+1, len(outputs))
-        #     self.assertEqual(model.config.output_attentions, True)
-        #     self.assertEqual(model.config.output_hidden_states, True)
-        #     attentions = outputs[-1]
-        #     self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-        #     self.assertListEqual(
-        #         list(attentions[0].shape[-3:]),
-        #         [self.model_tester.num_attention_heads,
-        #          self.model_tester.seq_length,
-        #          self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            config.output_attentions = True
+            config.output_hidden_states = False
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            attentions = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(model.config.output_attentions, True)
+            self.assertEqual(model.config.output_hidden_states, False)
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads,
+                 self.model_tester.seq_length,
+                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            config.output_attentions = True
+            config.output_hidden_states = True
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            self.assertEqual(out_len+1, len(outputs))
+            self.assertEqual(model.config.output_attentions, True)
+            self.assertEqual(model.config.output_hidden_states, True)
+            attentions = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads,
+                 self.model_tester.seq_length,
+                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
 
     def test_headmasking(self):
         pass
@@ -178,22 +174,20 @@ class TFCommonTestCases:
 
     def test_hidden_states_output(self):
-        pass
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        # for model_class in self.all_model_classes:
-        #     config.output_hidden_states = True
-        #     config.output_attentions = False
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     hidden_states = outputs[-1]
-        #     self.assertEqual(model.config.output_attentions, False)
-        #     self.assertEqual(model.config.output_hidden_states, True)
-        #     self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
-        #     self.assertListEqual(
-        #         list(hidden_states[0].shape[-2:]),
-        #         [self.model_tester.seq_length, self.model_tester.hidden_size])
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            config.output_hidden_states = True
+            config.output_attentions = False
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            hidden_states = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(model.config.output_attentions, False)
+            self.assertEqual(model.config.output_hidden_states, True)
+            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [self.model_tester.seq_length, self.model_tester.hidden_size])
 
     def test_resize_tokens_embeddings(self):
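To make the newly enabled shape checks concrete: with output_attentions=True the models return one attention tensor per layer as their last output, each ending in [num_attention_heads, seq_length, seq_length]; with output_hidden_states=True they return one hidden-state tensor per layer plus the embedding output. The standalone toy check below uses made-up sizes and random tensors, purely as an illustration of the assertions, not as test code from the commit.

import tensorflow as tf

num_hidden_layers, num_attention_heads, seq_length, batch_size = 2, 4, 8, 3

# Stand-in for `outputs[-1]` of a model run with config.output_attentions = True.
attentions = [tf.random.uniform([batch_size, num_attention_heads, seq_length, seq_length])
              for _ in range(num_hidden_layers)]
assert len(attentions) == num_hidden_layers
assert list(attentions[0].shape[-3:]) == [num_attention_heads, seq_length, seq_length]

# Stand-in for the hidden-states output: one tensor per layer plus the embeddings,
# each of shape [batch_size, seq_length, hidden_size].
hidden_size = 16
hidden_states = [tf.random.uniform([batch_size, seq_length, hidden_size])
                 for _ in range(num_hidden_layers + 1)]
assert len(hidden_states) == num_hidden_layers + 1
assert list(hidden_states[0].shape[-2:]) == [seq_length, hidden_size]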