2918b7d2
Commit
2918b7d2
authored
Jul 12, 2019
by
thomwolf
Browse files
updating tests
parent
3fbceed8
Changes
14
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
701 additions
and
625 deletions
+701
-625
pytorch_transformers/modeling_bert.py                     +5    -10
pytorch_transformers/modeling_gpt2.py                     +5    -10
pytorch_transformers/modeling_openai.py                   +5    -10
pytorch_transformers/modeling_transfo_xl.py               +15   -4
pytorch_transformers/modeling_utils.py                    +39   -8
pytorch_transformers/modeling_xlm.py                      +6    -4
pytorch_transformers/modeling_xlnet.py                    +7    -6
pytorch_transformers/tests/modeling_bert_test.py          +47   -39
pytorch_transformers/tests/modeling_common_test.py        +471  -437
pytorch_transformers/tests/modeling_gpt2_test.py          +4    -9
pytorch_transformers/tests/modeling_openai_test.py        +3    -4
pytorch_transformers/tests/modeling_transfo_xl_test.py    +29   -27
pytorch_transformers/tests/modeling_xlm_test.py           +27   -24
pytorch_transformers/tests/modeling_xlnet_test.py         +38   -33
pytorch_transformers/modeling_bert.py  (+5 -10)

@@ -617,6 +617,7 @@ class BertModel(BertPreTrainedModel):
         old_embeddings = self.embeddings.word_embeddings
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
         self.embeddings.word_embeddings = new_embeddings
+        return self.embeddings.word_embeddings

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.

@@ -758,11 +759,8 @@ class BertForPreTraining(BertPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        input_embeddings = self.bert.embeddings.word_embeddings.weight
-        if self.config.torchscript:
-            self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
-        else:
-            self.cls.predictions.decoder.weight = input_embeddings  # Tied weights
+        self._tie_or_clone_weights(self.cls.predictions.decoder,
+                                   self.bert.embeddings.word_embeddings)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):

@@ -864,11 +862,8 @@ class BertForMaskedLM(BertPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        input_embeddings = self.bert.embeddings.word_embeddings.weight
-        if self.config.torchscript:
-            self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
-        else:
-            self.cls.predictions.decoder.weight = input_embeddings  # Tied weights
+        self._tie_or_clone_weights(self.cls.predictions.decoder,
+                                   self.bert.embeddings.word_embeddings)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
         """
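For reference, a minimal usage sketch (illustrative, not part of the diff) of the behaviour the BERT heads now delegate to _tie_or_clone_weights: with the default config (torchscript is False) the MLM decoder ends up holding the very same Parameter as the input word embeddings. The tiny config values below just mirror the ones used in the test files.

# Usage sketch only: check the tied-weight behaviour of the refactored tie_weights().
from pytorch_transformers import BertConfig, BertForMaskedLM

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertForMaskedLM(config)        # config.torchscript defaults to False
model.tie_weights()
# shared: decoder weight and input embedding weight are the same Parameter object
assert model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight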
pytorch_transformers/modeling_gpt2.py  (+5 -10)

@@ -414,6 +414,7 @@ class GPT2Model(GPT2PreTrainedModel):
     def _resize_token_embeddings(self, new_num_tokens):
         self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
+        return self.wte

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.

@@ -562,11 +563,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        input_embeddings = self.transformer.wte.weight
-        if self.config.torchscript:
-            self.lm_head.weight = nn.Parameter(input_embeddings.clone())
-        else:
-            self.lm_head.weight = input_embeddings  # Tied weights
+        self._tie_or_clone_weights(self.lm_head, self.transformer.wte)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
         """

@@ -658,11 +656,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        input_embeddings = self.transformer.wte.weight
-        if self.config.torchscript:
-            self.lm_head.weight = nn.Parameter(input_embeddings.clone())
-        else:
-            self.lm_head.weight = input_embeddings  # Tied weights
+        self._tie_or_clone_weights(self.lm_head, self.transformer.wte)

     def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None, head_mask=None):
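Tying lm_head to wte works shape-wise because nn.Linear stores its weight as (out_features, in_features): assuming the LM head is a bias-free Linear from hidden size to vocabulary size (as is typical for these models, not shown in this hunk), its weight has exactly the (vocab_size, n_embd) shape of the input embedding. A standalone sketch in plain PyTorch:

# Standalone illustration (not library code): output projection and input embedding
# can share one weight tensor because both are (vocab_size, n_embd).
import torch
import torch.nn as nn

vocab_size, n_embd = 100, 16
wte = nn.Embedding(vocab_size, n_embd)               # weight: (vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight: (vocab_size, n_embd)
assert wte.weight.shape == lm_head.weight.shape

lm_head.weight = wte.weight                          # the tied (non-torchscript) branch
hidden = torch.randn(2, 3, n_embd)
logits = lm_head(hidden)
print(logits.shape)                                  # torch.Size([2, 3, 100])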
pytorch_transformers/modeling_openai.py  (+5 -10)

@@ -430,6 +430,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     def _resize_token_embeddings(self, new_num_tokens):
         self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
+        return self.tokens_embed

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.

@@ -583,11 +584,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        input_embeddings = self.transformer.tokens_embed.weight
-        if self.config.torchscript:
-            self.lm_head.weight = nn.Parameter(input_embeddings.clone())
-        else:
-            self.lm_head.weight = input_embeddings  # Tied weights
+        self._tie_or_clone_weights(self.lm_head, self.transformer.tokens_embed)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
         """

@@ -696,11 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        input_embeddings = self.transformer.tokens_embed.weight
-        if self.config.torchscript:
-            self.lm_head.weight = nn.Parameter(input_embeddings.clone())
-        else:
-            self.lm_head.weight = input_embeddings  # Tied weights
+        self._tie_or_clone_weights(self.lm_head, self.transformer.tokens_embed)

     def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, head_mask=None):
pytorch_transformers/modeling_transfo_xl.py  (+15 -4)

@@ -291,6 +291,10 @@ class TransfoXLConfig(PretrainedConfig):
     def vocab_size(self):
         return self.n_token

+    @vocab_size.setter
+    def vocab_size(self, value):
+        self.n_token = value
+
     @property
     def hidden_size(self):
         return self.d_model

@@ -1003,7 +1007,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.apply(self.init_weights)

     def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+        return self.word_emb

     def backward_compatible(self):
         self.sample_softmax = -1

@@ -1280,12 +1284,19 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         else:
             if self.config.tie_weight:
                 for i in range(len(self.crit.out_layers)):
-                    self.crit.out_layers[i].weight = self.transformer.word_emb.emb_layers[i].weight
+                    self._tie_or_clone_weights(self.crit.out_layers[i],
+                                               self.transformer.word_emb.emb_layers[i])
             if self.config.tie_projs:
                 for i, tie_proj in enumerate(self.config.tie_projs):
                     if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
-                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
+                        if self.config.torchscript:
+                            self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
+                        else:
+                            self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
                     elif tie_proj and self.config.div_val != 1:
-                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
+                        if self.config.torchscript:
+                            self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
+                        else:
+                            self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

     def reset_length(self, tgt_len, ext_len, mem_len):
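The new vocab_size setter mirrors the existing read-only property, so generic code (for example resize_token_embeddings, which assigns config.vocab_size) can update configs whose canonical attribute has another name (n_token here, n_words in XLM). A minimal standalone sketch of the pattern:

# Standalone sketch of the property/setter aliasing added to the config classes:
# vocab_size reads and writes the underlying n_token attribute.
class ToyConfig(object):
    def __init__(self, n_token=1000):
        self.n_token = n_token

    @property
    def vocab_size(self):
        return self.n_token

    @vocab_size.setter
    def vocab_size(self, value):
        self.n_token = value

cfg = ToyConfig()
cfg.vocab_size = 1234       # goes through the setter
print(cfg.n_token)          # 1234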
pytorch_transformers/modeling_utils.py  (+39 -8)

@@ -165,9 +165,27 @@ class PreTrainedModel(nn.Module):
         # Save config in model
         self.config = config

-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens):
-        # Build new embeddings
+    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+        """ Build a resized Embedding Module from a provided token Embedding Module.
+            Increasing the size will add newly initialized vectors at the end
+            Reducing the size will remove vectors from the end
+
+        Args:
+            new_num_tokens: (Optional) New number of tokens in the embedding matrix.
+                Increasing the size will add newly initialized vectors at the end
+                Reducing the size will remove vectors from the end
+                If not provided or None: return the provided token Embedding Module.
+        Return:
+            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
+        """
+        if new_num_tokens is None:
+            return old_embeddings
+
         old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return old_embeddings
+
+        # Build new embeddings
         new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
         new_embeddings.to(old_embeddings.weight.device)
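The code that actually copies the old vectors into the new matrix sits below the visible hunk. Purely as a generic illustration of what "increasing the size will add newly initialized vectors at the end" means (not the library implementation), a standalone sketch:

# Illustrative only: grow an embedding matrix while keeping the existing rows,
# leaving the appended rows with their fresh random initialization.
import torch
import torch.nn as nn

def grow_embedding(old_emb, new_num_tokens):
    old_num_tokens, dim = old_emb.weight.size()
    if new_num_tokens is None or new_num_tokens == old_num_tokens:
        return old_emb
    new_emb = nn.Embedding(new_num_tokens, dim)
    new_emb.to(old_emb.weight.device)
    num_to_copy = min(old_num_tokens, new_num_tokens)
    with torch.no_grad():
        new_emb.weight[:num_to_copy, :] = old_emb.weight[:num_to_copy, :]
    return new_emb

emb = nn.Embedding(10, 4)
bigger = grow_embedding(emb, 12)
print(bigger.weight.shape)                         # torch.Size([12, 4])
assert torch.equal(bigger.weight[:10], emb.weight)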
@@ -180,18 +198,29 @@ class PreTrainedModel(nn.Module):

         return new_embeddings

-    def resize_token_embeddings(self, new_num_tokens):
-        """ Resize input token embeddings matrix.
+    def _tie_or_clone_weights(self, first_module, second_module):
+        """ Tie or clone module weights depending of weither we are using TorchScript or not
+        """
+        if self.config.torchscript:
+            first_module.weight = nn.Parameter(second_module.weight.clone())
+        else:
+            first_module.weight = second_module.weight
+
+    def resize_token_embeddings(self, new_num_tokens=None):
+        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
+
         Args:
-            new_num_tokens: New number of tokens in the embedding matrix.
+            new_num_tokens: (Optional) New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end
                 Reducing the size will remove vectors from the end
+                If not provided or None: does nothing.
+        Return:
+            Pointer to the input tokens Embedding Module of the model
         """
-        if new_num_tokens == self.config.vocab_size:
-            return
-
         base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
-        base_model._resize_token_embeddings(new_num_tokens)
+        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
+        if new_num_tokens is None:
+            return model_embeds

         # Update base model and current model config
         self.config.vocab_size = new_num_tokens

@@ -201,6 +230,8 @@ class PreTrainedModel(nn.Module):
         if hasattr(self, 'tie_weights'):
             self.tie_weights()

+        return model_embeds
+
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
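A standalone sketch (plain PyTorch, illustrative only) of the two branches _tie_or_clone_weights switches between: the non-TorchScript branch makes both modules hold one shared Parameter, while the TorchScript branch copies the values into an independent Parameter.

# Illustrative only: the tie vs. clone behaviours selected by config.torchscript.
import torch
import torch.nn as nn

embedding = nn.Embedding(50, 8)
decoder = nn.Linear(8, 50, bias=False)

# torchscript=False branch: tie -> both modules reference the same Parameter
decoder.weight = embedding.weight
assert decoder.weight is embedding.weight

# torchscript=True branch: clone -> an independent Parameter with equal values
decoder.weight = nn.Parameter(embedding.weight.clone())
assert decoder.weight is not embedding.weight
assert torch.equal(decoder.weight, embedding.weight)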
pytorch_transformers/modeling_xlm.py  (+6 -4)

@@ -184,6 +184,10 @@ class XLMConfig(PretrainedConfig):
     def vocab_size(self):
         return self.n_words

+    @vocab_size.setter
+    def vocab_size(self, value):
+        self.n_words = value
+
     @property
     def hidden_size(self):
         return self.emb_dim

@@ -479,6 +483,7 @@ class XLMModel(XLMPreTrainedModel):
     def _resize_token_embeddings(self, new_num_tokens):
         self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
+        return self.embeddings

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.

@@ -728,10 +733,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        if self.config.torchscript:
-            self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
-        else:
-            self.pred_layer.proj.weight = self.transformer.embeddings.weight
+        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)

     def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None):
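Every _resize_token_embeddings override touched by this commit now returns the embedding module, so resize_token_embeddings can hand the caller a pointer to it. A usage sketch (illustrative, shown with one of the affected models; the same call works for the others, using the small config values the tests use):

# Usage sketch only: resize_token_embeddings returns the input embedding module
# and treats new_num_tokens=None as a no-op.
from pytorch_transformers import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertModel(config)

embeddings = model.resize_token_embeddings(None)   # no change, just a pointer
print(embeddings.weight.shape)                     # torch.Size([99, 32])

resized = model.resize_token_embeddings(99 + 8)    # 8 new vectors appended at the end
print(resized.weight.shape)                        # torch.Size([107, 32])
print(model.config.vocab_size)                     # 107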
pytorch_transformers/modeling_xlnet.py  (+7 -6)

@@ -316,6 +316,10 @@ class XLNetConfig(PretrainedConfig):
     def vocab_size(self):
         return self.n_token

+    @vocab_size.setter
+    def vocab_size(self, value):
+        self.n_token = value
+
     @property
     def hidden_size(self):
         return self.d_model

@@ -660,10 +664,10 @@ class XLNetModel(XLNetPreTrainedModel):
     def _resize_token_embeddings(self, new_num_tokens):
         self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
+        return self.word_embedding

     def _prune_heads(self, heads_to_prune):
-        logger.info("Head pruning is not implemented for XLNet")
-        pass
+        raise NotImplementedError

     def create_mask(self, qlen, mlen):
         """

@@ -987,10 +991,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        if self.config.torchscript:
-            self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
-        else:
-            self.lm_loss.weight = self.transformer.word_embedding.weight
+        self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
pytorch_transformers/tests/modeling_bert_test.py  (+47 -39)

@@ -26,10 +26,15 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                    BertForTokenClassification, BertForMultipleChoice)
 from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)


-class BertModelTest(unittest.TestCase):
+class BertModelTest(CommonTestCases.CommonModelTester):
+
+    all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                         BertForPreTraining, BertForQuestionAnswering,
+                         BertForSequenceClassification, BertForTokenClassification)
+
     class BertModelTester(object):

         def __init__(self,

@@ -55,9 +60,6 @@ class BertModelTest(unittest.TestCase):
                      num_labels=3,
                      num_choices=4,
                      scope=None,
-                     all_model_classes=(BertModel, BertForMaskedLM, BertForNextSentencePrediction,
-                                        BertForPreTraining, BertForQuestionAnswering,
-                                        BertForSequenceClassification, BertForTokenClassification),
                     ):
             self.parent = parent
             self.batch_size = batch_size

@@ -81,7 +83,6 @@ class BertModelTest(unittest.TestCase):
             self.num_labels = num_labels
             self.num_choices = num_choices
             self.scope = scope
-            self.all_model_classes = all_model_classes

         def prepare_config_and_inputs(self):
             input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

@@ -253,52 +254,59 @@ class BertModelTest(unittest.TestCase):
             self.check_loss_output(result)

-        def create_and_check_bert_commons(self, config, input_ids, token_type_ids, input_mask,
-                                          sequence_labels, token_labels, choice_labels):
+        def prepare_config_and_inputs_for_common(self):
+            config_and_inputs = self.prepare_config_and_inputs()
+            (config, input_ids, token_type_ids, input_mask,
+             sequence_labels, token_labels, choice_labels) = config_and_inputs
             inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
-            create_and_check_commons(self, config, inputs_dict)
+            return config, inputs_dict

-    def test_default(self):
-        self.run_tester(BertModelTest.BertModelTester(self))
+    def setUp(self):
+        self.model_tester = BertModelTest.BertModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

     def test_config(self):
-        config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
-        config_tester.run_common_tests()
+        self.config_tester.run_common_tests()

-    @pytest.mark.slow
-    def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(model)
-
-    def run_tester(self, tester):
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_model(*config_and_inputs)
+    def test_bert_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_model(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
+    def test_for_masked_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
+    def test_for_next_sequence_prediction(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_pretraining(*config_and_inputs)
+    def test_for_pretraining(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_question_answering(*config_and_inputs)
+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_for_token_classification(*config_and_inputs)
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_bert_commons(*config_and_inputs)
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_transformers_test/"
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)

 if __name__ == "__main__":
     unittest.main()
pytorch_transformers/tests/modeling_common_test.py  (+471 -437)

This diff is collapsed.
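The refactored modeling_common_test.py itself is collapsed above. Purely as a hypothetical sketch of the pattern the per-model test files rely on (class and attribute names are taken from their call sites in this commit; the real file is much larger and differs in detail), the shared scaffolding might be shaped roughly like this:

# Hypothetical sketch only -- not the contents of the collapsed file.
import unittest
import torch

class CommonTestCases(object):
    # Nesting the base class inside a plain object keeps it from being collected
    # as a test case at module level.

    class CommonModelTester(unittest.TestCase):
        # Per-model test classes (BertModelTest, XLNetModelTest, ...) override these.
        all_model_classes = ()
        test_pruning = True
        test_torchscript = True
        test_resize_embeddings = True

        def test_common_forward(self):
            # Each subclass's setUp() creates self.model_tester, whose
            # prepare_config_and_inputs_for_common() returns (config, inputs_dict).
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            for model_class in self.all_model_classes:
                model = model_class(config)
                model.eval()
                with torch.no_grad():
                    outputs = model(**inputs_dict)
                self.assertIsInstance(outputs, tuple)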
pytorch_transformers/tests/modeling_gpt2_test.py  (+4 -9)

@@ -16,19 +16,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import unittest
-import json
-import random
-import shutil
 import pytest

-import torch
-
 from pytorch_transformers import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel)

-from .modeling_common_test import (create_and_check_commons, ConfigTester, GPTModelTester)
+from .modeling_common_test import CommonTestCases, ConfigTester


 class GPT2ModelTest(unittest.TestCase):

@@ -37,14 +32,14 @@ class GPT2ModelTest(unittest.TestCase):
         config_tester.run_common_tests()

     def test_model(self):
-        model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
+        model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                       lm_head_model_class=GPT2LMHeadModel,
                                       double_head_model_class=GPT2DoubleHeadsModel)
         model_tester.run_common_tests(test_presents=True)

     @pytest.mark.slow
     def test_pretrained(self):
-        model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
+        model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
                                       lm_head_model_class=GPT2LMHeadModel,
                                       double_head_model_class=GPT2DoubleHeadsModel)
         model_tester.run_slow_tests()
pytorch_transformers/tests/modeling_openai_test.py  (+3 -4)

@@ -19,12 +19,11 @@ from __future__ import print_function
 import unittest
 import pytest

-import torch
-
 from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

-from .modeling_common_test import (create_and_check_commons, ConfigTester, GPTModelTester)
+from .modeling_common_test import CommonTestCases, ConfigTester


 class OpenAIModelTest(unittest.TestCase):

@@ -33,14 +32,14 @@ class OpenAIModelTest(unittest.TestCase):
         config_tester.run_common_tests()

     def test_model(self):
-        model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
+        model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
                                       lm_head_model_class=OpenAIGPTLMHeadModel,
                                       double_head_model_class=OpenAIGPTDoubleHeadsModel)
         model_tester.run_common_tests(test_presents=False)

     @pytest.mark.slow
     def test_pretrained(self):
-        model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
+        model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
                                       lm_head_model_class=OpenAIGPTLMHeadModel,
                                       double_head_model_class=OpenAIGPTDoubleHeadsModel)
         model_tester.run_slow_tests()
pytorch_transformers/tests/modeling_transfo_xl_test.py  (+29 -27)

@@ -28,9 +28,15 @@ import torch
 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
 from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import ConfigTester, create_and_check_commons, ids_tensor
+from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor


-class TransfoXLModelTest(unittest.TestCase):
+class TransfoXLModelTest(CommonTestCases.CommonModelTester):
+
+    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel)
+    test_pruning = False
+    test_torchscript = False
+    test_resize_embeddings = False
+
     class TransfoXLModelTester(object):

         def __init__(self,

@@ -52,7 +58,6 @@ class TransfoXLModelTest(unittest.TestCase):
                      num_hidden_layers=5,
                      scope=None,
                      seed=1,
-                     all_model_classes=(TransfoXLModel, TransfoXLLMHeadModel),
                     ):
             self.parent = parent
             self.batch_size = batch_size

@@ -73,7 +78,6 @@ class TransfoXLModelTest(unittest.TestCase):
             self.num_hidden_layers = num_hidden_layers
             self.scope = scope
             self.seed = seed
-            self.all_model_classes = all_model_classes

         def prepare_config_and_inputs(self):
             input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

@@ -171,16 +175,31 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(list(mem.size()) for mem in result["mems_2"]),
                 [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

-        def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
+        def prepare_config_and_inputs_for_common(self):
+            config_and_inputs = self.prepare_config_and_inputs()
+            (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)
+            return config, inputs_dict

-    def test_default(self):
-        self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
+    def setUp(self):
+        self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)

     def test_config(self):
-        config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
-        config_tester.run_common_tests()
+        self.config_tester.run_common_tests()
+
+    def test_transfo_xl_model(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        output_result = self.model_tester.create_transfo_xl_model(*config_and_inputs)
+        self.model_tester.check_transfo_xl_model_output(output_result)
+
+    def test_transfo_xl_lm_head(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
+        self.model_tester.check_transfo_xl_lm_head_output(output_result)

     @pytest.mark.slow
     def test_model_from_pretrained(self):

@@ -190,23 +209,6 @@ class TransfoXLModelTest(unittest.TestCase):
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

-    def run_tester(self, tester):
-        config_and_inputs = tester.prepare_config_and_inputs()
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        output_result = tester.create_transfo_xl_model(*config_and_inputs)
-        tester.check_transfo_xl_model_output(output_result)
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
-        tester.check_transfo_xl_lm_head_output(output_result)
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_transfo_xl_commons(*config_and_inputs)
-
 if __name__ == "__main__":
     unittest.main()
pytorch_transformers/tests/modeling_xlm_test.py  (+27 -24)

@@ -23,10 +23,15 @@ import pytest
 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel,
                                   XLMForQuestionAnswering, XLMForSequenceClassification)
 from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)


-class XLMModelTest(unittest.TestCase):
+class XLMModelTest(CommonTestCases.CommonModelTester):
+
+    all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
+                         XLMForSequenceClassification)
+    # , XLMForSequenceClassification, XLMForTokenClassification),
+
     class XLMModelTester(object):

         def __init__(self,

@@ -58,8 +63,6 @@ class XLMModelTest(unittest.TestCase):
                      summary_type="last",
                      use_proj=True,
                      scope=None,
-                     all_model_classes=(XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
-                                        XLMForSequenceClassification),
-                     # , XLMForSequenceClassification, XLMForTokenClassification),
                     ):
             self.parent = parent
             self.batch_size = batch_size

@@ -90,7 +93,6 @@ class XLMModelTest(unittest.TestCase):
             self.num_labels = num_labels
             self.num_choices = num_choices
             self.scope = scope
-            self.all_model_classes = all_model_classes

         def prepare_config_and_inputs(self):
             input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

@@ -237,28 +239,23 @@ class XLMModelTest(unittest.TestCase):
                 [self.batch_size, self.type_sequence_label_size])

-        def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths,
-                                         sequence_labels, token_labels, is_impossible_labels, input_mask):
+        def prepare_config_and_inputs_for_common(self):
+            config_and_inputs = self.prepare_config_and_inputs()
+            (config, input_ids, token_type_ids, input_lengths,
+             sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
             inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
-            create_and_check_commons(self, config, inputs_dict)
+            return config, inputs_dict

-    def test_default(self):
-        self.run_tester(XLMModelTest.XLMModelTester(self))
+    def setUp(self):
+        self.model_tester = XLMModelTest.XLMModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

     def test_config(self):
-        config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
-        config_tester.run_common_tests()
+        self.config_tester.run_common_tests()

-    @pytest.mark.slow
-    def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(model)
-
-    def run_tester(self, tester):
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlm_model(*config_and_inputs)
+    def test_xlm_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

         # config_and_inputs = tester.prepare_config_and_inputs()
         # tester.create_and_check_xlm_for_masked_lm(*config_and_inputs)

@@ -275,8 +272,14 @@ class XLMModelTest(unittest.TestCase):
         # config_and_inputs = tester.prepare_config_and_inputs()
         # tester.create_and_check_xlm_for_token_classification(*config_and_inputs)

-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlm_commons(*config_and_inputs)
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_transformers_test/"
+        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)

 if __name__ == "__main__":
     unittest.main()
pytorch_transformers/tests/modeling_xlnet_test.py  (+38 -33)

@@ -28,9 +28,14 @@ import torch
 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel,
                                   XLNetForSequenceClassification, XLNetForQuestionAnswering)
 from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import ConfigTester, create_and_check_commons, ids_tensor
+from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor


-class XLNetModelTest(unittest.TestCase):
+class XLNetModelTest(CommonTestCases.CommonModelTester):
+
+    all_model_classes = (XLNetModel, XLNetLMHeadModel,
+                         XLNetForSequenceClassification, XLNetForQuestionAnswering)
+    test_pruning = False
+
     class XLNetModelTester(object):

         def __init__(self,

@@ -56,8 +61,6 @@ class XLNetModelTest(unittest.TestCase):
                      initializer_range=0.05,
                      seed=1,
                      type_vocab_size=2,
-                     all_model_classes=(XLNetModel, XLNetLMHeadModel,
-                                        XLNetForSequenceClassification, XLNetForQuestionAnswering),
                     ):
             self.parent = parent
             self.batch_size = batch_size

@@ -82,7 +85,6 @@ class XLNetModelTest(unittest.TestCase):
             self.seed = seed
             self.type_vocab_size = type_vocab_size
             self.type_sequence_label_size = type_sequence_label_size
-            self.all_model_classes = all_model_classes

         def prepare_config_and_inputs(self):
             input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

@@ -264,17 +266,41 @@ class XLNetModelTest(unittest.TestCase):
                 list(list(mem.size()) for mem in result["mems_1"]),
                 [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

-        def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q,
-                                           perm_mask, input_mask, target_mapping, inp_q, segment_ids,
-                                           lm_labels, sequence_labels, is_impossible_labels):
+        def prepare_config_and_inputs_for_common(self):
+            config_and_inputs = self.prepare_config_and_inputs()
+            (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
+             target_mapping, inp_q, segment_ids, lm_labels,
+             sequence_labels, is_impossible_labels) = config_and_inputs
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict, test_pruning=False)
+            return config, inputs_dict

-    def test_default(self):
-        self.run_tester(XLNetModelTest.XLNetModelTester(self))
+    def setUp(self):
+        self.model_tester = XLNetModelTest.XLNetModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

     def test_config(self):
-        config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
-        config_tester.run_common_tests()
+        self.config_tester.run_common_tests()
+
+    def test_xlnet_base_model(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
+
+    def test_xlnet_lm_head(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)
+
+    def test_xlnet_sequence_classif(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
+
+    def test_xlnet_qa(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

     @pytest.mark.slow
     def test_model_from_pretrained(self):

@@ -284,27 +310,6 @@ class XLNetModelTest(unittest.TestCase):
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

-    def run_tester(self, tester):
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlnet_base_model(*config_and_inputs)
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlnet_lm_head(*config_and_inputs)
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlnet_qa(*config_and_inputs)
-
-        tester.set_seed()
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.create_and_check_xlnet_commons(*config_and_inputs)
-
 if __name__ == "__main__":
     unittest.main()