Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aa4c8804
Commit
aa4c8804
authored
Sep 05, 2019
by
thomwolf
Browse files
skipping tf tests if tf is not installed
parent
134847db
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
17 deletions
+17
-17
pytorch_transformers/tests/modeling_tf_bert_test.py
pytorch_transformers/tests/modeling_tf_bert_test.py
+9
-9
pytorch_transformers/tests/modeling_tf_common_test.py
pytorch_transformers/tests/modeling_tf_common_test.py
+8
-8
No files found.
pytorch_transformers/tests/modeling_tf_bert_test.py
View file @
aa4c8804
...
@@ -287,48 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -287,48 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def
test_config
(
self
):
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
self
.
config_tester
.
run_common_tests
()
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_bert_model
(
self
):
def
test_bert_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_model
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_model
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_masked_lm
(
self
):
def
test_for_masked_lm
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_masked_lm
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_masked_lm
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_multiple_choice
(
self
):
def
test_for_multiple_choice
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_multiple_choice
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_multiple_choice
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_next_sequence_prediction
(
self
):
def
test_for_next_sequence_prediction
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_next_sequence_prediction
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_next_sequence_prediction
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_pretraining
(
self
):
def
test_for_pretraining
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_pretraining
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_pretraining
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_question_answering
(
self
):
def
test_for_question_answering
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_question_answering
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_question_answering
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_sequence_classification
(
self
):
def
test_for_sequence_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_sequence_classification
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_sequence_classification
(
*
config_and_inputs
)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_for_token_classification
(
self
):
def
test_for_token_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_bert_for_token_classification
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_bert_for_token_classification
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_transformers_test/"
cache_dir
=
"/tmp/pytorch_transformers_test/"
for
model_name
in
list
(
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
...
...
pytorch_transformers/tests/modeling_tf_common_test.py
View file @
aa4c8804
...
@@ -50,7 +50,7 @@ class TFCommonTestCases:
...
@@ -50,7 +50,7 @@ class TFCommonTestCases:
test_pruning
=
True
test_pruning
=
True
test_resize_embeddings
=
True
test_resize_embeddings
=
True
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_initialization
(
self
):
def
test_initialization
(
self
):
pass
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -64,7 +64,7 @@ class TFCommonTestCases:
...
@@ -64,7 +64,7 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_attention_outputs
(
self
):
def
test_attention_outputs
(
self
):
pass
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -105,7 +105,7 @@ class TFCommonTestCases:
...
@@ -105,7 +105,7 @@ class TFCommonTestCases:
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_headmasking
(
self
):
def
test_headmasking
(
self
):
pass
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -153,7 +153,7 @@ class TFCommonTestCases:
...
@@ -153,7 +153,7 @@ class TFCommonTestCases:
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_head_pruning
(
self
):
def
test_head_pruning
(
self
):
pass
pass
# if not self.test_pruning:
# if not self.test_pruning:
...
@@ -181,7 +181,7 @@ class TFCommonTestCases:
...
@@ -181,7 +181,7 @@ class TFCommonTestCases:
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_hidden_states_output
(
self
):
def
test_hidden_states_output
(
self
):
pass
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -201,7 +201,7 @@ class TFCommonTestCases:
...
@@ -201,7 +201,7 @@ class TFCommonTestCases:
# [self.model_tester.seq_length, self.model_tester.hidden_size])
# [self.model_tester.seq_length, self.model_tester.hidden_size])
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_resize_tokens_embeddings
(
self
):
def
test_resize_tokens_embeddings
(
self
):
pass
pass
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -238,7 +238,7 @@ class TFCommonTestCases:
...
@@ -238,7 +238,7 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
# self.assertTrue(models_equal)
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_tie_model_weights
(
self
):
def
test_tie_model_weights
(
self
):
pass
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -290,7 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
...
@@ -290,7 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
class
TFModelUtilsTest
(
unittest
.
TestCase
):
class
TFModelUtilsTest
(
unittest
.
TestCase
):
@
pytest
.
mark
.
skipif
(
't
f
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
@
pytest
.
mark
.
skipif
(
't
ensorflow
'
not
in
sys
.
modules
,
reason
=
"requires TensorFlow"
)
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
pass
pass
# logging.basicConfig(level=logging.INFO)
# logging.basicConfig(level=logging.INFO)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment