Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fdcde144
Unverified
Commit
fdcde144
authored
Jan 29, 2021
by
Julien Plu
Committed by
GitHub
Jan 29, 2021
Browse files
Add XLA test (#9848)
parent
99b9affa
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
85 additions
and
0 deletions
+85
-0
tests/test_modeling_tf_bart.py
tests/test_modeling_tf_bart.py
+4
-0
tests/test_modeling_tf_blenderbot.py
tests/test_modeling_tf_blenderbot.py
+4
-0
tests/test_modeling_tf_blenderbot_small.py
tests/test_modeling_tf_blenderbot_small.py
+4
-0
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+13
-0
tests/test_modeling_tf_convbert.py
tests/test_modeling_tf_convbert.py
+4
-0
tests/test_modeling_tf_ctrl.py
tests/test_modeling_tf_ctrl.py
+4
-0
tests/test_modeling_tf_flaubert.py
tests/test_modeling_tf_flaubert.py
+4
-0
tests/test_modeling_tf_gpt2.py
tests/test_modeling_tf_gpt2.py
+4
-0
tests/test_modeling_tf_led.py
tests/test_modeling_tf_led.py
+4
-0
tests/test_modeling_tf_longformer.py
tests/test_modeling_tf_longformer.py
+4
-0
tests/test_modeling_tf_marian.py
tests/test_modeling_tf_marian.py
+4
-0
tests/test_modeling_tf_mbart.py
tests/test_modeling_tf_mbart.py
+4
-0
tests/test_modeling_tf_mpnet.py
tests/test_modeling_tf_mpnet.py
+4
-0
tests/test_modeling_tf_openai.py
tests/test_modeling_tf_openai.py
+4
-0
tests/test_modeling_tf_pegasus.py
tests/test_modeling_tf_pegasus.py
+4
-0
tests/test_modeling_tf_t5.py
tests/test_modeling_tf_t5.py
+8
-0
tests/test_modeling_tf_transfo_xl.py
tests/test_modeling_tf_transfo_xl.py
+4
-0
tests/test_modeling_tf_xlm.py
tests/test_modeling_tf_xlm.py
+4
-0
No files found.
tests/test_modeling_tf_bart.py
View file @
fdcde144
...
...
@@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make BART float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make BART XLA compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_blenderbot.py
View file @
fdcde144
...
...
@@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Blenderbot XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_blenderbot_small.py
View file @
fdcde144
...
...
@@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_common.py
View file @
fdcde144
...
...
@@ -141,6 +141,19 @@ class TFModelTesterMixin:
outputs
=
run_in_graph_mode
()
self
.
assertIsNotNone
(
outputs
)
def
test_xla_mode
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
)
@
tf
.
function
(
experimental_compile
=
True
)
def
run_in_graph_mode
():
return
model
(
inputs
)
outputs
=
run_in_graph_mode
()
self
.
assertIsNotNone
(
outputs
)
def
test_forward_signature
(
self
):
config
,
_
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_convbert.py
View file @
fdcde144
...
...
@@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
[
self
.
model_tester
.
num_attention_heads
/
2
,
encoder_seq_length
,
encoder_key_length
],
)
def
test_xla_mode
(
self
):
# TODO JP: Make ConvBert XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFConvBertModel
.
from_pretrained
(
"YituTech/conv-bert-base"
)
...
...
tests/test_modeling_tf_ctrl.py
View file @
fdcde144
...
...
@@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make CTRL float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make CTRL XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_flaubert.py
View file @
fdcde144
...
...
@@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Flaubert float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Flaubert XLA compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_gpt2.py
View file @
fdcde144
...
...
@@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make GPT2 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make GPT2 XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_led.py
View file @
fdcde144
...
...
@@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make LED float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make LED XLA compliant
pass
def
test_saved_model_with_attentions_output
(
self
):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
...
...
tests/test_modeling_tf_longformer.py
View file @
fdcde144
...
...
@@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Longformer float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Blenderbot XLA compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_marian.py
View file @
fdcde144
...
...
@@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Marian float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Marian XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mbart.py
View file @
fdcde144
...
...
@@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make MBart float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make MBart XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mpnet.py
View file @
fdcde144
...
...
@@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_mpnet_for_token_classification
(
*
config_and_inputs
)
def
test_xla_mode
(
self
):
# TODO JP: Make MPNet XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
[
"microsoft/mpnet-base"
]:
...
...
tests/test_modeling_tf_openai.py
View file @
fdcde144
...
...
@@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make OpenAIGPT float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make OpenAIGPT XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_pegasus.py
View file @
fdcde144
...
...
@@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Pegasus float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Pegasus XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_t5.py
View file @
fdcde144
...
...
@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make T5 XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFT5Model
.
from_pretrained
(
"t5-small"
)
...
...
@@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make T5 XLA compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_transfo_xl.py
View file @
fdcde144
...
...
@@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make TransfoXL float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make TransfoXL XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_xlm.py
View file @
fdcde144
...
...
@@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make XLM float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make XLM XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment