Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fdcde144
"sgl-kernel/vscode:/vscode.git/clone" did not exist on "37565b7f2164d56d02aca52470812ad967b4d317"
Unverified
Commit
fdcde144
authored
Jan 29, 2021
by
Julien Plu
Committed by
GitHub
Jan 29, 2021
Browse files
Add XLA test (#9848)
parent
99b9affa
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
85 additions
and
0 deletions
+85
-0
tests/test_modeling_tf_bart.py
tests/test_modeling_tf_bart.py
+4
-0
tests/test_modeling_tf_blenderbot.py
tests/test_modeling_tf_blenderbot.py
+4
-0
tests/test_modeling_tf_blenderbot_small.py
tests/test_modeling_tf_blenderbot_small.py
+4
-0
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+13
-0
tests/test_modeling_tf_convbert.py
tests/test_modeling_tf_convbert.py
+4
-0
tests/test_modeling_tf_ctrl.py
tests/test_modeling_tf_ctrl.py
+4
-0
tests/test_modeling_tf_flaubert.py
tests/test_modeling_tf_flaubert.py
+4
-0
tests/test_modeling_tf_gpt2.py
tests/test_modeling_tf_gpt2.py
+4
-0
tests/test_modeling_tf_led.py
tests/test_modeling_tf_led.py
+4
-0
tests/test_modeling_tf_longformer.py
tests/test_modeling_tf_longformer.py
+4
-0
tests/test_modeling_tf_marian.py
tests/test_modeling_tf_marian.py
+4
-0
tests/test_modeling_tf_mbart.py
tests/test_modeling_tf_mbart.py
+4
-0
tests/test_modeling_tf_mpnet.py
tests/test_modeling_tf_mpnet.py
+4
-0
tests/test_modeling_tf_openai.py
tests/test_modeling_tf_openai.py
+4
-0
tests/test_modeling_tf_pegasus.py
tests/test_modeling_tf_pegasus.py
+4
-0
tests/test_modeling_tf_t5.py
tests/test_modeling_tf_t5.py
+8
-0
tests/test_modeling_tf_transfo_xl.py
tests/test_modeling_tf_transfo_xl.py
+4
-0
tests/test_modeling_tf_xlm.py
tests/test_modeling_tf_xlm.py
+4
-0
No files found.
tests/test_modeling_tf_bart.py
View file @
fdcde144
...
...
@@ -281,6 +281,10 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make BART float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make BART XLA compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_blenderbot.py
View file @
fdcde144
...
...
@@ -217,6 +217,10 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Blenderbot XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_blenderbot_small.py
View file @
fdcde144
...
...
@@ -282,6 +282,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def
_assert_tensors_equal
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
...
...
tests/test_modeling_tf_common.py
View file @
fdcde144
...
...
@@ -141,6 +141,19 @@ class TFModelTesterMixin:
outputs
=
run_in_graph_mode
()
self
.
assertIsNotNone
(
outputs
)
def
test_xla_mode
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
)
@
tf
.
function
(
experimental_compile
=
True
)
def
run_in_graph_mode
():
return
model
(
inputs
)
outputs
=
run_in_graph_mode
()
self
.
assertIsNotNone
(
outputs
)
def
test_forward_signature
(
self
):
config
,
_
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_convbert.py
View file @
fdcde144
...
...
@@ -301,6 +301,10 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
[
self
.
model_tester
.
num_attention_heads
/
2
,
encoder_seq_length
,
encoder_key_length
],
)
def
test_xla_mode
(
self
):
# TODO JP: Make ConvBert XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFConvBertModel
.
from_pretrained
(
"YituTech/conv-bert-base"
)
...
...
tests/test_modeling_tf_ctrl.py
View file @
fdcde144
...
...
@@ -225,6 +225,10 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make CTRL float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make CTRL XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_flaubert.py
View file @
fdcde144
...
...
@@ -334,6 +334,10 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Flaubert float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Flaubert XLA compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_gpt2.py
View file @
fdcde144
...
...
@@ -391,6 +391,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make GPT2 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make GPT2 XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_led.py
View file @
fdcde144
...
...
@@ -361,6 +361,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make LED float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make LED XLA compliant
pass
def
test_saved_model_with_attentions_output
(
self
):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
...
...
tests/test_modeling_tf_longformer.py
View file @
fdcde144
...
...
@@ -359,6 +359,10 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Longformer float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Blenderbot XLA compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_marian.py
View file @
fdcde144
...
...
@@ -250,6 +250,10 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Marian float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Marian XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mbart.py
View file @
fdcde144
...
...
@@ -221,6 +221,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make MBart float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make MBart XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_mpnet.py
View file @
fdcde144
...
...
@@ -231,6 +231,10 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_mpnet_for_token_classification
(
*
config_and_inputs
)
def
test_xla_mode
(
self
):
# TODO JP: Make MPNet XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
[
"microsoft/mpnet-base"
]:
...
...
tests/test_modeling_tf_openai.py
View file @
fdcde144
...
...
@@ -249,6 +249,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make OpenAIGPT float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make OpenAIGPT XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_pegasus.py
View file @
fdcde144
...
...
@@ -248,6 +248,10 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make Pegasus float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make Pegasus XLA compliant
pass
def
test_resize_token_embeddings
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_tf_t5.py
View file @
fdcde144
...
...
@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make T5 XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFT5Model
.
from_pretrained
(
"t5-small"
)
...
...
@@ -443,6 +447,10 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make T5 float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make T5 XLA compliant
pass
@
require_tf
@
require_sentencepiece
...
...
tests/test_modeling_tf_transfo_xl.py
View file @
fdcde144
...
...
@@ -208,6 +208,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make TransfoXL float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make TransfoXL XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
tests/test_modeling_tf_xlm.py
View file @
fdcde144
...
...
@@ -330,6 +330,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make XLM float16 compliant
pass
def
test_xla_mode
(
self
):
# TODO JP: Make XLM XLA compliant
pass
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment