Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ad0ab9af
Commit
ad0ab9af
authored
Sep 05, 2019
by
thomwolf
Browse files
fix test when tf is not here
parent
59fe641b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
103 additions
and
48 deletions
+103
-48
.circleci/config.yml
.circleci/config.yml
+2
-0
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+62
-32
pytorch_transformers/tests/modeling_tf_bert_test.py
pytorch_transformers/tests/modeling_tf_bert_test.py
+19
-5
pytorch_transformers/tests/modeling_tf_common_test.py
pytorch_transformers/tests/modeling_tf_common_test.py
+20
-11
No files found.
.circleci/config.yml
View file @
ad0ab9af
...
...
@@ -11,6 +11,7 @@ jobs:
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest codecov pytest-cov
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
sudo pip install tensorflow==2.0.0-rc0
-
run
:
python -m pytest -sv ./pytorch_transformers/tests/ --cov
-
run
:
python -m pytest -sv ./examples/
-
run
:
codecov
...
...
@@ -24,6 +25,7 @@ jobs:
-
checkout
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest codecov pytest-cov
-
run
:
sudo pip install tensorflow==2.0.0-rc0
-
run
:
python -m pytest -sv ./pytorch_transformers/tests/ --cov
-
run
:
codecov
deploy_doc
:
...
...
pytorch_transformers/__init__.py
View file @
ad0ab9af
__version__
=
"1.2.0"
# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
# see: https://github.com/abseil/abseil-py/issues/99
...
...
@@ -11,6 +12,10 @@ try:
except
:
pass
import
logging
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
# Tokenizer
from
.tokenization_utils
import
(
PreTrainedTokenizer
)
from
.tokenization_auto
import
AutoTokenizer
...
...
@@ -36,38 +41,63 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
from
.configuration_distilbert
import
DistilBertConfig
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
# Modeling
from
.modeling_utils
import
(
PreTrainedModel
,
prune_layer
,
Conv1D
)
from
.modeling_auto
import
(
AutoModel
,
AutoModelForSequenceClassification
,
AutoModelForQuestionAnswering
,
AutoModelWithLMHead
)
from
.modeling_bert
import
(
BertPreTrainedModel
,
BertModel
,
BertForPreTraining
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForSequenceClassification
,
BertForMultipleChoice
,
BertForTokenClassification
,
BertForQuestionAnswering
,
load_tf_weights_in_bert
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_openai
import
(
OpenAIGPTPreTrainedModel
,
OpenAIGPTModel
,
OpenAIGPTLMHeadModel
,
OpenAIGPTDoubleHeadsModel
,
load_tf_weights_in_openai_gpt
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_transfo_xl
import
(
TransfoXLPreTrainedModel
,
TransfoXLModel
,
TransfoXLLMHeadModel
,
load_tf_weights_in_transfo_xl
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_gpt2
import
(
GPT2PreTrainedModel
,
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
,
load_tf_weights_in_gpt2
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlnet
import
(
XLNetPreTrainedModel
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
load_tf_weights_in_xlnet
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlm
import
(
XLMPreTrainedModel
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_roberta
import
(
RobertaForMaskedLM
,
RobertaModel
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_distilbert
import
(
DistilBertForMaskedLM
,
DistilBertModel
,
DistilBertForSequenceClassification
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# Optimization
from
.optimization
import
(
AdamW
,
ConstantLRSchedule
,
WarmupConstantSchedule
,
WarmupCosineSchedule
,
WarmupCosineWithHardRestartsSchedule
,
WarmupLinearSchedule
)
try
:
import
torch
torch_available
=
True
# pylint: disable=invalid-name
except
ImportError
:
torch_available
=
False
# pylint: disable=invalid-name
if
torch_available
:
logger
.
info
(
"PyTorch version {} available."
.
format
(
torch
.
__version__
))
from
.modeling_utils
import
(
PreTrainedModel
,
prune_layer
,
Conv1D
)
from
.modeling_auto
import
(
AutoModel
,
AutoModelForSequenceClassification
,
AutoModelForQuestionAnswering
,
AutoModelWithLMHead
)
from
.modeling_bert
import
(
BertPreTrainedModel
,
BertModel
,
BertForPreTraining
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForSequenceClassification
,
BertForMultipleChoice
,
BertForTokenClassification
,
BertForQuestionAnswering
,
load_tf_weights_in_bert
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_openai
import
(
OpenAIGPTPreTrainedModel
,
OpenAIGPTModel
,
OpenAIGPTLMHeadModel
,
OpenAIGPTDoubleHeadsModel
,
load_tf_weights_in_openai_gpt
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_transfo_xl
import
(
TransfoXLPreTrainedModel
,
TransfoXLModel
,
TransfoXLLMHeadModel
,
load_tf_weights_in_transfo_xl
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_gpt2
import
(
GPT2PreTrainedModel
,
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
,
load_tf_weights_in_gpt2
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlnet
import
(
XLNetPreTrainedModel
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
load_tf_weights_in_xlnet
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlm
import
(
XLMPreTrainedModel
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_roberta
import
(
RobertaForMaskedLM
,
RobertaModel
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_distilbert
import
(
DistilBertForMaskedLM
,
DistilBertModel
,
DistilBertForSequenceClassification
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# Optimization
from
.optimization
import
(
AdamW
,
ConstantLRSchedule
,
WarmupConstantSchedule
,
WarmupCosineSchedule
,
WarmupCosineWithHardRestartsSchedule
,
WarmupLinearSchedule
)
# TensorFlow
try
:
import
tensorflow
as
tf
tf_available
=
True
# pylint: disable=invalid-name
except
ImportError
:
tf_available
=
False
# pylint: disable=invalid-name
if
tf_available
:
logger
.
info
(
"TensorFlow version {} available."
.
format
(
tf
.
__version__
))
from
.modeling_tf_utils
import
TFPreTrainedModel
from
.modeling_tf_bert
import
(
TFBertPreTrainedModel
,
TFBertModel
,
TFBertForPreTraining
,
TFBertForMaskedLM
,
TFBertForNextSentencePrediction
,
load_pt_weights_in_bert
)
# Files and general utilities
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
...
...
pytorch_transformers/tests/modeling_tf_test.py
→
pytorch_transformers/tests/modeling_tf_
bert_
test.py
View file @
ad0ab9af
...
...
@@ -19,15 +19,19 @@ from __future__ import print_function
import
unittest
import
shutil
import
pytest
import
tensorflow
as
tf
from
pytorch_transformers
import
(
BertConfig
)
from
pytorch_transformers.modeling_tf_bert
import
TFBertModel
,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
try
:
import
tensorflow
as
tf
from
pytorch_transformers
import
(
BertConfig
)
from
pytorch_transformers.modeling_tf_bert
import
TFBertModel
,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except
ImportError
:
pass
class
TFBertModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
...
...
@@ -283,39 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_bert_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_model
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_masked_lm
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_masked_lm
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_multiple_choice
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_multiple_choice
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_next_sequence_prediction
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_next_sequence_prediction
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_pretraining
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_pretraining
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_question_answering
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_question_answering
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_sequence_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_sequence_classification
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_token_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_token_classification
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_transformers_test/"
for
model_name
in
list
(
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
...
...
@@ -325,3 +338,4 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
if
__name__
==
"__main__"
:
unittest
.
main
()
pytorch_transformers/tests/modeling_tf_common_test.py
View file @
ad0ab9af
...
...
@@ -12,24 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
copy
import
os
import
shutil
import
json
import
logging
import
random
import
uuid
import
shutil
import
unittest
import
logging
import
uuid
import
tensorflow
as
tf
import
pytest
import
sys
from
pytorch_transformers
import
TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
try
:
import
tensorflow
as
tf
from
pytorch_transformers
import
TFPreTrainedModel
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
except
ImportError
:
pass
def
_config_zero_init
(
config
):
...
...
@@ -49,6 +50,7 @@ class TFCommonTestCases:
test_pruning
=
True
test_resize_embeddings
=
True
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_initialization
(
self
):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
...
@@ -62,6 +64,7 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_attention_outputs
(
self
):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
...
@@ -102,6 +105,7 @@ class TFCommonTestCases:
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_headmasking
(
self
):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
...
@@ -149,6 +153,7 @@ class TFCommonTestCases:
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_head_pruning
(
self
):
pass
# if not self.test_pruning:
...
...
@@ -176,6 +181,7 @@ class TFCommonTestCases:
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_hidden_states_output
(
self
):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
...
@@ -195,6 +201,7 @@ class TFCommonTestCases:
# [self.model_tester.seq_length, self.model_tester.hidden_size])
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_resize_tokens_embeddings
(
self
):
pass
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
...
@@ -231,6 +238,7 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_tie_model_weights
(
self
):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
...
@@ -282,6 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
class
TFModelUtilsTest
(
unittest
.
TestCase
):
@
pytest
.
mark
.
skipif
(
'tf'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_model_from_pretrained
(
self
):
pass
# logging.basicConfig(level=logging.INFO)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment