Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2c891c15
Unverified
Commit
2c891c15
authored
Jan 27, 2021
by
Julien Plu
Committed by
GitHub
Jan 27, 2021
Browse files
Add a test for mixed precision (#9806)
parent
d5b40d66
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
105 additions
and
8 deletions
+105
-8
tests/test_modeling_tf_albert.py
tests/test_modeling_tf_albert.py
+4
-0
tests/test_modeling_tf_bart.py
tests/test_modeling_tf_bart.py
+4
-0
tests/test_modeling_tf_blenderbot.py
tests/test_modeling_tf_blenderbot.py
+4
-0
tests/test_modeling_tf_blenderbot_small.py
tests/test_modeling_tf_blenderbot_small.py
+4
-0
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+14
-0
tests/test_modeling_tf_ctrl.py
tests/test_modeling_tf_ctrl.py
+4
-0
tests/test_modeling_tf_flaubert.py
tests/test_modeling_tf_flaubert.py
+2
-6
tests/test_modeling_tf_funnel.py
tests/test_modeling_tf_funnel.py
+8
-0
tests/test_modeling_tf_gpt2.py
tests/test_modeling_tf_gpt2.py
+4
-0
tests/test_modeling_tf_led.py
tests/test_modeling_tf_led.py
+4
-0
tests/test_modeling_tf_longformer.py
tests/test_modeling_tf_longformer.py
+13
-2
tests/test_modeling_tf_lxmert.py
tests/test_modeling_tf_lxmert.py
+4
-0
tests/test_modeling_tf_marian.py
tests/test_modeling_tf_marian.py
+4
-0
tests/test_modeling_tf_mbart.py
tests/test_modeling_tf_mbart.py
+4
-0
tests/test_modeling_tf_mobilebert.py
tests/test_modeling_tf_mobilebert.py
+4
-0
tests/test_modeling_tf_openai.py
tests/test_modeling_tf_openai.py
+4
-0
tests/test_modeling_tf_pegasus.py
tests/test_modeling_tf_pegasus.py
+4
-0
tests/test_modeling_tf_t5.py
tests/test_modeling_tf_t5.py
+8
-0
tests/test_modeling_tf_transfo_xl.py
tests/test_modeling_tf_transfo_xl.py
+4
-0
tests/test_modeling_tf_xlm.py
tests/test_modeling_tf_xlm.py
+4
-0
No files found.
tests/test_modeling_tf_albert.py
View file @
2c891c15
...
...
@@ -294,6 +294,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
name
=
model
.
get_bias
()
assert
name
is
None
def
test_mixed_precision
(
self
):
# TODO JP: Make ALBERT float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_bart.py
View file @
2c891c15
...
...
@@ -278,6 +278,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make BART float16 compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_blenderbot.py
View file @
2c891c15
...
...
@@ -214,6 +214,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Blenderbot float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_blenderbot_small.py
View file @
2c891c15
...
...
@@ -279,6 +279,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_common.py
View file @
2c891c15
...
...
@@ -279,6 +279,20 @@ class TFModelTesterMixin:
[
self
.
model_tester
.
num_attention_heads
,
encoder_seq_length
,
encoder_key_length
],
)
def
test_mixed_precision
(
self
):
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
"mixed_float16"
)
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
class_inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
)
outputs
=
model
(
class_inputs_dict
)
self
.
assertIsNotNone
(
outputs
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
"float32"
)
def
test_keras_save_load
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_ctrl.py
View file @
2c891c15
...
...
@@ -221,6 +221,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
name
=
model
.
get_bias
()
assert
name
is
None
def
test_mixed_precision
(
self
):
# TODO JP: Make CTRL float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_flaubert.py
View file @
2c891c15
...
...
@@ -330,12 +330,8 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
model
=
TFFlaubertModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
def
test_saved_model_with_hidden_states_output
(
self
):
# Should be uncommented during patrick TF refactor
pass
def
test_saved_model_with_attentions_output
(
self
):
# Should be uncommented during patrick TF refactor
def
test_mixed_precision
(
self
):
# TODO JP: Make Flaubert float16 compliant
pass
...
...
tests/test_modeling_tf_funnel.py
View file @
2c891c15
...
...
@@ -371,6 +371,10 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Funnel float16 compliant
pass
@
require_tf
class
TFFunnelBaseModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
...
...
@@ -401,3 +405,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_saved_model_creation
(
self
):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Funnel float16 compliant
pass
tests/test_modeling_tf_gpt2.py
View file @
2c891c15
...
...
@@ -387,6 +387,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_for_sequence_classification
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make GPT2 float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_led.py
View file @
2c891c15
...
...
@@ -357,6 +357,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make LED float16 compliant
pass
def
test_saved_model_with_attentions_output
(
self
):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
...
...
tests/test_modeling_tf_longformer.py
View file @
2c891c15
...
...
@@ -340,14 +340,25 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
@
slow
def
test_saved_model_with_attentions_output
(
self
):
# longformer has special attentions which are not
# compatible in graph mode
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
pass
def
test_saved_model_creation
(
self
):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Longformer float16 compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_lxmert.py
View file @
2c891c15
...
...
@@ -704,6 +704,10 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Lxmert float16 compliant
pass
@
slow
def
test_saved_model_with_hidden_states_output
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_marian.py
View file @
2c891c15
...
...
@@ -247,6 +247,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Marian float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mbart.py
View file @
2c891c15
...
...
@@ -218,6 +218,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make MBart float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mobilebert.py
View file @
2c891c15
...
...
@@ -309,6 +309,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make MobileBert float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
...
tests/test_modeling_tf_openai.py
View file @
2c891c15
...
...
@@ -245,6 +245,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_openai_gpt_for_sequence_classification
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make OpenAIGPT float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_pegasus.py
View file @
2c891c15
...
...
@@ -245,6 +245,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Pegasus float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_t5.py
View file @
2c891c15
...
...
@@ -306,6 +306,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make T5 float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFT5Model
.
from_pretrained
(
"t5-small"
)
...
...
@@ -435,6 +439,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_train_pipeline_custom_model
(
self
):
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make T5 float16 compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_transfo_xl.py
View file @
2c891c15
...
...
@@ -204,6 +204,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
name
=
model
.
get_bias
()
assert
name
is
None
def
test_mixed_precision
(
self
):
# TODO JP: Make TransfoXL float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_xlm.py
View file @
2c891c15
...
...
@@ -326,6 +326,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlm_for_multiple_choice
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make XLM float16 compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment