Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2c891c15
Unverified
Commit
2c891c15
authored
Jan 27, 2021
by
Julien Plu
Committed by
GitHub
Jan 27, 2021
Browse files
Add a test for mixed precision (#9806)
parent
d5b40d66
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
105 additions
and
8 deletions
+105
-8
tests/test_modeling_tf_albert.py
tests/test_modeling_tf_albert.py
+4
-0
tests/test_modeling_tf_bart.py
tests/test_modeling_tf_bart.py
+4
-0
tests/test_modeling_tf_blenderbot.py
tests/test_modeling_tf_blenderbot.py
+4
-0
tests/test_modeling_tf_blenderbot_small.py
tests/test_modeling_tf_blenderbot_small.py
+4
-0
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+14
-0
tests/test_modeling_tf_ctrl.py
tests/test_modeling_tf_ctrl.py
+4
-0
tests/test_modeling_tf_flaubert.py
tests/test_modeling_tf_flaubert.py
+2
-6
tests/test_modeling_tf_funnel.py
tests/test_modeling_tf_funnel.py
+8
-0
tests/test_modeling_tf_gpt2.py
tests/test_modeling_tf_gpt2.py
+4
-0
tests/test_modeling_tf_led.py
tests/test_modeling_tf_led.py
+4
-0
tests/test_modeling_tf_longformer.py
tests/test_modeling_tf_longformer.py
+13
-2
tests/test_modeling_tf_lxmert.py
tests/test_modeling_tf_lxmert.py
+4
-0
tests/test_modeling_tf_marian.py
tests/test_modeling_tf_marian.py
+4
-0
tests/test_modeling_tf_mbart.py
tests/test_modeling_tf_mbart.py
+4
-0
tests/test_modeling_tf_mobilebert.py
tests/test_modeling_tf_mobilebert.py
+4
-0
tests/test_modeling_tf_openai.py
tests/test_modeling_tf_openai.py
+4
-0
tests/test_modeling_tf_pegasus.py
tests/test_modeling_tf_pegasus.py
+4
-0
tests/test_modeling_tf_t5.py
tests/test_modeling_tf_t5.py
+8
-0
tests/test_modeling_tf_transfo_xl.py
tests/test_modeling_tf_transfo_xl.py
+4
-0
tests/test_modeling_tf_xlm.py
tests/test_modeling_tf_xlm.py
+4
-0
No files found.
tests/test_modeling_tf_albert.py
View file @
2c891c15
...
@@ -294,6 +294,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -294,6 +294,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
name
=
model
.
get_bias
()
name
=
model
.
get_bias
()
assert
name
is
None
assert
name
is
None
def
test_mixed_precision
(
self
):
# TODO JP: Make ALBERT float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_bart.py
View file @
2c891c15
...
@@ -278,6 +278,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -278,6 +278,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make BART float16 compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_blenderbot.py
View file @
2c891c15
...
@@ -214,6 +214,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -214,6 +214,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Blenderbot float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_blenderbot_small.py
View file @
2c891c15
...
@@ -279,6 +279,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -279,6 +279,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_common.py
View file @
2c891c15
...
@@ -279,6 +279,20 @@ class TFModelTesterMixin:
...
@@ -279,6 +279,20 @@ class TFModelTesterMixin:
[
self
.
model_tester
.
num_attention_heads
,
encoder_seq_length
,
encoder_key_length
],
[
self
.
model_tester
.
num_attention_heads
,
encoder_seq_length
,
encoder_key_length
],
)
)
def
test_mixed_precision
(
self
):
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
"mixed_float16"
)
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
class_inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
)
outputs
=
model
(
class_inputs_dict
)
self
.
assertIsNotNone
(
outputs
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
"float32"
)
def
test_keras_save_load
(
self
):
def
test_keras_save_load
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_ctrl.py
View file @
2c891c15
...
@@ -221,6 +221,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -221,6 +221,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
name
=
model
.
get_bias
()
name
=
model
.
get_bias
()
assert
name
is
None
assert
name
is
None
def
test_mixed_precision
(
self
):
# TODO JP: Make CTRL float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_flaubert.py
View file @
2c891c15
...
@@ -330,12 +330,8 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -330,12 +330,8 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
model
=
TFFlaubertModel
.
from_pretrained
(
model_name
)
model
=
TFFlaubertModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
def
test_saved_model_with_hidden_states_output
(
self
):
def
test_mixed_precision
(
self
):
# Should be uncommented during patrick TF refactor
# TODO JP: Make Flaubert float16 compliant
pass
def
test_saved_model_with_attentions_output
(
self
):
# Should be uncommented during patrick TF refactor
pass
pass
...
...
tests/test_modeling_tf_funnel.py
View file @
2c891c15
...
@@ -371,6 +371,10 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -371,6 +371,10 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Funnel float16 compliant
pass
@
require_tf
@
require_tf
class
TFFunnelBaseModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
class
TFFunnelBaseModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
...
@@ -401,3 +405,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -401,3 +405,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_saved_model_creation
(
self
):
def
test_saved_model_creation
(
self
):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Funnel float16 compliant
pass
tests/test_modeling_tf_gpt2.py
View file @
2c891c15
...
@@ -387,6 +387,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -387,6 +387,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_for_sequence_classification
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_gpt2_for_sequence_classification
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make GPT2 float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_led.py
View file @
2c891c15
...
@@ -357,6 +357,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -357,6 +357,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make LED float16 compliant
pass
def
test_saved_model_with_attentions_output
(
self
):
def
test_saved_model_with_attentions_output
(
self
):
# This test don't pass because of the error:
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
...
...
tests/test_modeling_tf_longformer.py
View file @
2c891c15
...
@@ -340,14 +340,25 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -340,14 +340,25 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
@
slow
@
slow
def
test_saved_model_with_attentions_output
(
self
):
def
test_saved_model_with_attentions_output
(
self
):
# longformer has special attentions which are not
# This test don't pass because of the error:
# compatible in graph mode
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
pass
pass
def
test_saved_model_creation
(
self
):
def
test_saved_model_creation
(
self
):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Longformer float16 compliant
pass
@
require_tf
@
require_tf
@
require_sentencepiece
@
require_sentencepiece
...
...
tests/test_modeling_tf_lxmert.py
View file @
2c891c15
...
@@ -704,6 +704,10 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -704,6 +704,10 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Lxmert float16 compliant
pass
@
slow
@
slow
def
test_saved_model_with_hidden_states_output
(
self
):
def
test_saved_model_with_hidden_states_output
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_marian.py
View file @
2c891c15
...
@@ -247,6 +247,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -247,6 +247,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Marian float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mbart.py
View file @
2c891c15
...
@@ -218,6 +218,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -218,6 +218,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make MBart float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mobilebert.py
View file @
2c891c15
...
@@ -309,6 +309,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -309,6 +309,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make MobileBert float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
...
tests/test_modeling_tf_openai.py
View file @
2c891c15
...
@@ -245,6 +245,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -245,6 +245,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_openai_gpt_for_sequence_classification
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_openai_gpt_for_sequence_classification
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make OpenAIGPT float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_pegasus.py
View file @
2c891c15
...
@@ -245,6 +245,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -245,6 +245,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make Pegasus float16 compliant
pass
def
test_resize_token_embeddings
(
self
):
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_t5.py
View file @
2c891c15
...
@@ -306,6 +306,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -306,6 +306,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
# This test is too long (>30sec) and makes fail the CI
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make T5 float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
model
=
TFT5Model
.
from_pretrained
(
"t5-small"
)
model
=
TFT5Model
.
from_pretrained
(
"t5-small"
)
...
@@ -435,6 +439,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -435,6 +439,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_train_pipeline_custom_model
(
self
):
def
test_train_pipeline_custom_model
(
self
):
pass
pass
def
test_mixed_precision
(
self
):
# TODO JP: Make T5 float16 compliant
pass
@
require_tf
@
require_tf
@
require_sentencepiece
@
require_sentencepiece
...
...
tests/test_modeling_tf_transfo_xl.py
View file @
2c891c15
...
@@ -204,6 +204,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -204,6 +204,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
name
=
model
.
get_bias
()
name
=
model
.
get_bias
()
assert
name
is
None
assert
name
is
None
def
test_mixed_precision
(
self
):
# TODO JP: Make TransfoXL float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_xlm.py
View file @
2c891c15
...
@@ -326,6 +326,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -326,6 +326,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlm_for_multiple_choice
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_xlm_for_multiple_choice
(
*
config_and_inputs
)
def
test_mixed_precision
(
self
):
# TODO JP: Make XLM float16 compliant
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment