Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b18509c2
Commit
b18509c2
authored
Nov 08, 2019
by
Lysandre
Committed by
Lysandre Debut
Nov 26, 2019
Browse files
Tests for ALBERT in TF2 + fixes
parent
7bddbf59
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
237 additions
and
11 deletions
+237
-11
transformers/__init__.py
transformers/__init__.py
+1
-0
transformers/modeling_tf_albert.py
transformers/modeling_tf_albert.py
+7
-11
transformers/tests/modeling_tf_albert_test.py
transformers/tests/modeling_tf_albert_test.py
+229
-0
No files found.
transformers/__init__.py
View file @
b18509c2
...
@@ -169,6 +169,7 @@ if is_tf_available():
...
@@ -169,6 +169,7 @@ if is_tf_available():
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_albert
import
(
TFAlbertPreTrainedModel
,
TFAlbertModel
,
TFAlbertForMaskedLM
,
from
.modeling_tf_albert
import
(
TFAlbertPreTrainedModel
,
TFAlbertModel
,
TFAlbertForMaskedLM
,
TFAlbertForSequenceClassification
,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# TF 2.0 <=> PyTorch conversion utilities
# TF 2.0 <=> PyTorch conversion utilities
...
...
transformers/modeling_tf_albert.py
View file @
b18509c2
...
@@ -126,16 +126,8 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
...
@@ -126,16 +126,8 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
"""
"""
batch_size
=
tf
.
shape
(
inputs
)[
0
]
batch_size
=
tf
.
shape
(
inputs
)[
0
]
length
=
tf
.
shape
(
inputs
)[
1
]
length
=
tf
.
shape
(
inputs
)[
1
]
print
(
inputs
.
shape
)
x
=
tf
.
reshape
(
inputs
,
[
-
1
,
self
.
config
.
embedding_size
])
x
=
tf
.
reshape
(
inputs
,
[
-
1
,
self
.
config
.
embedding_size
])
print
(
x
.
shape
,
self
.
word_embeddings
)
logits
=
tf
.
matmul
(
x
,
self
.
word_embeddings
,
transpose_b
=
True
)
logits
=
tf
.
matmul
(
x
,
self
.
word_embeddings
,
transpose_b
=
True
)
print
([
batch_size
,
length
,
self
.
config
.
vocab_size
])
return
tf
.
reshape
(
logits
,
[
batch_size
,
length
,
self
.
config
.
vocab_size
])
return
tf
.
reshape
(
logits
,
[
batch_size
,
length
,
self
.
config
.
vocab_size
])
...
@@ -460,20 +452,24 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
...
@@ -460,20 +452,24 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
# The output weights are the same as the input embeddings, but there is
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
# an output-only bias for each token.
self
.
input_embeddings
=
input_embeddings
self
.
decoder
=
input_embeddings
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
self
.
bias
=
self
.
add_weight
(
shape
=
(
self
.
vocab_size
,),
self
.
bias
=
self
.
add_weight
(
shape
=
(
self
.
vocab_size
,),
initializer
=
'zeros'
,
initializer
=
'zeros'
,
trainable
=
True
,
trainable
=
True
,
name
=
'bias'
)
name
=
'bias'
)
self
.
decoder_bias
=
self
.
add_weight
(
shape
=
(
self
.
vocab_size
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'decoder/bias'
)
super
(
TFAlbertMLMHead
,
self
).
build
(
input_shape
)
super
(
TFAlbertMLMHead
,
self
).
build
(
input_shape
)
def
call
(
self
,
hidden_states
):
def
call
(
self
,
hidden_states
):
hidden_states
=
self
.
dense
(
hidden_states
)
hidden_states
=
self
.
dense
(
hidden_states
)
hidden_states
=
self
.
activation
(
hidden_states
)
hidden_states
=
self
.
activation
(
hidden_states
)
hidden_states
=
self
.
LayerNorm
(
hidden_states
)
hidden_states
=
self
.
LayerNorm
(
hidden_states
)
hidden_states
=
self
.
input_embeddings
(
hidden_states
,
mode
=
"linear"
)
hidden_states
=
self
.
decoder
(
hidden_states
,
mode
=
"linear"
)
+
self
.
decoder_bias
hidden_states
=
hidden_states
+
self
.
bias
hidden_states
=
hidden_states
+
self
.
bias
return
hidden_states
return
hidden_states
...
@@ -666,7 +662,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
...
@@ -666,7 +662,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
[
embedding_output
,
extended_attention_mask
,
head_mask
],
training
=
training
)
[
embedding_output
,
extended_attention_mask
,
head_mask
],
training
=
training
)
sequence_output
=
encoder_outputs
[
0
]
sequence_output
=
encoder_outputs
[
0
]
pooled_output
=
self
.
pooler
(
sequence_output
)
pooled_output
=
self
.
pooler
(
sequence_output
[:,
0
]
)
# add hidden_states and attentions if they are here
# add hidden_states and attentions if they are here
outputs
=
(
sequence_output
,
pooled_output
,)
+
encoder_outputs
[
1
:]
outputs
=
(
sequence_output
,
pooled_output
,)
+
encoder_outputs
[
1
:]
...
...
transformers/tests/modeling_tf_albert_test.py
0 → 100644
View file @
b18509c2
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
shutil
import
pytest
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
transformers
import
AlbertConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers.modeling_tf_albert
import
(
TFAlbertModel
,
TFAlbertForMaskedLM
,
TFAlbertForSequenceClassification
,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
class
TFAlbertModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFAlbertModel
,
TFAlbertForMaskedLM
,
TFAlbertForSequenceClassification
)
if
is_tf_available
()
else
()
class
TFAlbertModelTester
(
object
):
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
is_training
=
True
,
use_input_mask
=
True
,
use_token_type_ids
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
hidden_size
=
32
,
num_hidden_layers
=
5
,
num_attention_heads
=
4
,
intermediate_size
=
37
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
type_vocab_size
=
16
,
type_sequence_label_size
=
2
,
initializer_range
=
0.02
,
num_labels
=
3
,
num_choices
=
4
,
scope
=
None
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
seq_length
=
seq_length
self
.
is_training
=
is_training
self
.
use_input_mask
=
use_input_mask
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
intermediate_size
=
intermediate_size
self
.
hidden_act
=
hidden_act
self
.
hidden_dropout_prob
=
hidden_dropout_prob
self
.
attention_probs_dropout_prob
=
attention_probs_dropout_prob
self
.
max_position_embeddings
=
max_position_embeddings
self
.
type_vocab_size
=
type_vocab_size
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
initializer_range
=
initializer_range
self
.
num_labels
=
num_labels
self
.
num_choices
=
num_choices
self
.
scope
=
scope
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
(
[
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
(
[
self
.
batch_size
,
self
.
seq_length
],
vocab_size
=
2
)
token_type_ids
=
None
if
self
.
use_token_type_ids
:
token_type_ids
=
ids_tensor
(
[
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
sequence_labels
=
None
token_labels
=
None
choice_labels
=
None
if
self
.
use_labels
:
sequence_labels
=
ids_tensor
(
[
self
.
batch_size
],
self
.
type_sequence_label_size
)
token_labels
=
ids_tensor
(
[
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
AlbertConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
num_attention_heads
=
self
.
num_attention_heads
,
intermediate_size
=
self
.
intermediate_size
,
hidden_act
=
self
.
hidden_act
,
hidden_dropout_prob
=
self
.
hidden_dropout_prob
,
attention_probs_dropout_prob
=
self
.
attention_probs_dropout_prob
,
max_position_embeddings
=
self
.
max_position_embeddings
,
type_vocab_size
=
self
.
type_vocab_size
,
initializer_range
=
self
.
initializer_range
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_albert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFAlbertModel
(
config
=
config
)
# inputs = {'input_ids': input_ids,
# 'attention_mask': input_mask,
# 'token_type_ids': token_type_ids}
# sequence_output, pooled_output = model(**inputs)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
sequence_output
,
pooled_output
=
model
(
inputs
)
inputs
=
[
input_ids
,
input_mask
]
sequence_output
,
pooled_output
=
model
(
inputs
)
sequence_output
,
pooled_output
=
model
(
input_ids
)
result
=
{
"sequence_output"
:
sequence_output
.
numpy
(),
"pooled_output"
:
pooled_output
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"pooled_output"
].
shape
),
[
self
.
batch_size
,
self
.
hidden_size
])
def
create_and_check_albert_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFAlbertForMaskedLM
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
prediction_scores
,
=
model
(
inputs
)
result
=
{
"prediction_scores"
:
prediction_scores
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_and_check_albert_for_sequence_classification
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
config
.
num_labels
=
self
.
num_labels
model
=
TFAlbertForSequenceClassification
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
logits
,
=
model
(
inputs
)
result
=
{
"logits"
:
logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
num_labels
])
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'attention_mask'
:
input_mask
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFAlbertModelTest
.
TFAlbertModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
AlbertConfig
,
hidden_size
=
37
)
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
def
test_albert_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_albert_model
(
*
config_and_inputs
)
def
test_for_masked_lm
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_albert_for_masked_lm
(
*
config_and_inputs
)
def
test_for_sequence_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_albert_for_sequence_classification
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for
model_name
in
[
'albert-base-uncased'
]:
model
=
TFAlbertModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment