Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6d90d76f
Unverified
Commit
6d90d76f
authored
Apr 22, 2022
by
Joao Gante
Committed by
GitHub
Apr 22, 2022
Browse files
TF: rework XLA generate tests (#16866)
parent
3b1bbefc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
54 deletions
+76
-54
tests/gpt2/test_modeling_tf_gpt2.py
tests/gpt2/test_modeling_tf_gpt2.py
+47
-31
tests/t5/test_modeling_tf_t5.py
tests/t5/test_modeling_tf_t5.py
+29
-23
No files found.
tests/gpt2/test_modeling_tf_gpt2.py
View file @
6d90d76f
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
import
unittest
import
unittest
from
transformers
import
GPT2Config
,
is_tf_available
from
transformers
import
GPT2Config
,
is_tf_available
from
transformers.testing_utils
import
require_tf
,
slow
from
transformers.testing_utils
import
get_gpu_count
,
require_tf
,
slow
from
..test_configuration_common
import
ConfigTester
from
..test_configuration_common
import
ConfigTester
from
..test_modeling_tf_common
import
TFModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
from
..test_modeling_tf_common
import
TFModelTesterMixin
,
floats_tensor
,
ids_tensor
,
random_attention_mask
...
@@ -294,7 +294,7 @@ class TFGPT2ModelTester:
...
@@ -294,7 +294,7 @@ class TFGPT2ModelTester:
result
=
model
(
inputs
)
result
=
model
(
inputs
)
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
))
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
))
def
create_and_check_gpt2_xla_generate
(
self
,
config
,
input_ids
,
*
args
):
def
create_and_check_gpt2_xla_generate
_fast
(
self
,
config
,
input_ids
,
*
args
):
config
.
eos_token_id
=
None
config
.
eos_token_id
=
None
config
.
max_length
=
10
config
.
max_length
=
10
model
=
TFGPT2LMHeadModel
(
config
=
config
)
model
=
TFGPT2LMHeadModel
(
config
=
config
)
...
@@ -408,9 +408,9 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
...
@@ -408,9 +408,9 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_lm_head
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_gpt2_lm_head
(
*
config_and_inputs
)
def
test_gpt2_xla_generate
(
self
):
def
test_gpt2_xla_generate
_fast
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_xla_generate
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_gpt2_xla_generate
_fast
(
*
config_and_inputs
)
def
test_gpt2_double_head
(
self
):
def
test_gpt2_double_head
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
...
@@ -536,41 +536,57 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
...
@@ -536,41 +536,57 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self
.
assertListEqual
(
output_strings
,
expected_output_string
)
self
.
assertListEqual
(
output_strings
,
expected_output_string
)
@
slow
@
slow
def
test_lm_generate_gpt2
(
self
):
@
unittest
.
skipIf
(
not
get_gpu_count
(),
"XLA not reliable on CPU"
)
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def
test_lm_generate_gpt2_greedy_xla
(
self
):
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
# the underlying problem)
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
input_ids
=
tf
.
convert_to_tensor
([[
464
,
3290
]],
dtype
=
tf
.
int32
)
# The dog
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"gpt2"
)
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
tokenizer
.
pad_token
=
tokenizer
.
eos_token
# fmt: off
tokenizer
.
padding_side
=
"left"
expected_output_ids
=
[
464
,
3290
,
373
,
1043
,
287
,
257
,
2214
,
1474
,
262
,
16246
,
286
,
2688
,
290
,
2688
,
27262
,
13
,
198
,
198
,
464
,
3290
]
# fmt: on
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
False
)
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
@
slow
sentences
=
[
"The dog"
]
def
test_lm_generate_gpt2_xla_greedy
(
self
):
expected_output_strings
=
[
"""This test gives the exact same results as the non-xla test above"""
"The dog was found in a field near the intersection of West and West Streets.
\n\n
The dog"
,
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
]
input_ids
=
t
f
.
convert_to_tensor
([[
464
,
3290
]],
dtype
=
tf
.
int32
)
# The dog
input_ids
=
t
okenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
False
)
# fmt: off
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
expected_output_ids
=
[
464
,
3290
,
373
,
1043
,
287
,
257
,
2214
,
1474
,
262
,
16246
,
286
,
2688
,
290
,
2688
,
27262
,
13
,
198
,
198
,
464
,
3290
]
self
.
assertListEqual
(
output_strings
,
expected_output_strings
)
# fmt: on
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
output_ids
=
xla_generate
(
input_ids
,
do_sample
=
False
)
output_ids
=
xla_generate
(
input_ids
,
do_sample
=
False
)
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
output_strings
,
expected_output_strings
)
@
slow
@
slow
def
test_lm_generate_gpt2_xla_sample
(
self
):
@
unittest
.
skipIf
(
not
get_gpu_count
(),
"XLA not reliable on CPU"
)
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def
test_lm_generate_gpt2_sample_xla
(
self
):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
# and that we can seed both versions.
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
model
=
TFGPT2LMHeadModel
.
from_pretrained
(
"gpt2"
)
input_ids
=
tf
.
convert_to_tensor
([[
464
,
3290
]],
dtype
=
tf
.
int32
)
# The dog
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"gpt2"
)
# fmt: off
tokenizer
.
pad_token
=
tokenizer
.
eos_token
expected_output_ids
=
[
464
,
3290
,
550
,
284
,
307
,
4376
,
287
,
281
,
4044
,
1363
,
329
,
734
,
812
,
878
,
852
,
4376
,
757
,
329
,
2267
,
0
]
tokenizer
.
padding_side
=
"left"
# fmt: on
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
sentence
=
[
"The dog"
]
expected_output_string
=
[
"The dog must be well educated to do anything. If anything, this must be her best friend"
]
expected_output_string_xla
=
[
"The dog has been named in connection with the murder of a 20-year-old man in!"
]
input_ids
=
tokenizer
(
sentence
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
output_ids
=
xla_generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
42
,
0
])
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
7
,
0
])
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
output_strings
,
expected_output_string
)
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
output_ids
=
xla_generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
7
,
0
])
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
output_strings
,
expected_output_string_xla
)
tests/t5/test_modeling_tf_t5.py
View file @
6d90d76f
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
import
unittest
import
unittest
from
transformers
import
T5Config
,
is_tf_available
from
transformers
import
T5Config
,
is_tf_available
from
transformers.testing_utils
import
require_sentencepiece
,
require_tf
,
require_tokenizers
,
slow
from
transformers.testing_utils
import
get_gpu_count
,
require_sentencepiece
,
require_tf
,
require_tokenizers
,
slow
from
transformers.utils
import
cached_property
from
transformers.utils
import
cached_property
from
..test_configuration_common
import
ConfigTester
from
..test_configuration_common
import
ConfigTester
...
@@ -227,7 +227,7 @@ class TFT5ModelTester:
...
@@ -227,7 +227,7 @@ class TFT5ModelTester:
# test that outputs are equal for slice
# test that outputs are equal for slice
tf
.
debugging
.
assert_near
(
output_from_past_slice
,
output_from_no_past_slice
,
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
output_from_past_slice
,
output_from_no_past_slice
,
rtol
=
1e-3
)
def
create_and_check_t5_xla_generate
(
self
,
config
,
input_ids
,
*
args
):
def
create_and_check_t5_xla_generate
_fast
(
self
,
config
,
input_ids
,
*
args
):
config
.
eos_token_id
=
None
config
.
eos_token_id
=
None
config
.
max_length
=
10
config
.
max_length
=
10
config
.
do_sample
=
False
config
.
do_sample
=
False
...
@@ -297,9 +297,9 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -297,9 +297,9 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_t5_decoder_model_past_large_inputs
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_t5_decoder_model_past_large_inputs
(
*
config_and_inputs
)
def
test_t5_model_xla_generate
(
self
):
def
test_t5_model_xla_generate
_fast
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_t5_xla_generate
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_t5_xla_generate
_fast
(
*
config_and_inputs
)
def
test_model_common_attributes
(
self
):
def
test_model_common_attributes
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
@@ -481,12 +481,18 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -481,12 +481,18 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
@
require_tokenizers
@
require_tokenizers
class
TFT5GenerationIntegrationTests
(
unittest
.
TestCase
):
class
TFT5GenerationIntegrationTests
(
unittest
.
TestCase
):
@
slow
@
slow
@
unittest
.
skipIf
(
not
get_gpu_count
(),
"XLA not reliable on CPU"
)
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def
test_greedy_xla_generate_simple
(
self
):
def
test_greedy_xla_generate_simple
(
self
):
model
=
TFT5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
)
model
=
TFT5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
)
tokenizer
=
T5Tokenizer
.
from_pretrained
(
"t5-small"
)
tokenizer
=
T5Tokenizer
.
from_pretrained
(
"t5-small"
)
sentence
=
"Translate English to German: Today is a beautiful day."
# two examples with different lengths to confirm that attention masks are operational in XLA
input_ids
=
tokenizer
(
sentence
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
sentences
=
[
"Translate English to German: Today is a beautiful day."
,
"Translate English to German: I have four cats, three dogs, two birds, and a horse."
,
]
input_ids
=
tokenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
...
@@ -496,7 +502,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
...
@@ -496,7 +502,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
output_strings_xla
=
tokenizer
.
batch_decode
(
output_ids_xla
,
skip_special_tokens
=
True
)
output_strings_xla
=
tokenizer
.
batch_decode
(
output_ids_xla
,
skip_special_tokens
=
True
)
expected_output_string
=
[
"Heute ist ein schöner Tag."
]
expected_output_string
=
[
"Heute ist ein schöner Tag."
,
"Ich habe vier Katzen, drei Hunde, zwei Vögel und ein Pferd."
,
]
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
self
.
assertListEqual
(
expected_output_string
,
output_strings_xla
)
self
.
assertListEqual
(
expected_output_string
,
output_strings_xla
)
...
@@ -525,28 +534,25 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
...
@@ -525,28 +534,25 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
@
slow
@
slow
@
unittest
.
skipIf
(
not
get_gpu_count
(),
"XLA not reliable on CPU"
)
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def
test_sample_xla_generate_simple
(
self
):
def
test_sample_xla_generate_simple
(
self
):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
# results are sensible and that we can seed both versions.
model
=
TFT5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
)
model
=
TFT5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
)
tokenizer
=
T5Tokenizer
.
from_pretrained
(
"t5-small"
)
tokenizer
=
T5Tokenizer
.
from_pretrained
(
"t5-small"
)
sentence
=
"Translate English to German:
Today is a beautiful day.
"
sentence
=
"Translate English to German:
I have two bananas
"
input_ids
=
tokenizer
(
sentence
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
input_ids
=
tokenizer
(
sentence
,
return_tensors
=
"tf"
,
padding
=
True
).
input_ids
# XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
expected_output_string
=
[
"Ich habe 2 Bananen"
]
# divergences in generate -- especially with sampling.
expected_output_string_xla
=
[
"Ich habe 2 Bananen"
]
expected_output_string
=
[
"Heute ist ein schöner Tag."
]
expected_output_string_xla
=
[
"Heute ist ein schöne Tage."
]
# However, notice that the first tokens are the same, for the same seed
assert
expected_output_string
[
0
][:
15
]
==
expected_output_string_xla
[
0
][:
15
]
# forces the generation to happen on CPU, to avoid GPU-related quirks
with
tf
.
device
(
":/CPU:0"
):
# seed set -> deterministic sampling sequence -> deterministic generation
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
42
,
0
])
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
42
,
0
])
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
output_strings
=
tokenizer
.
batch_decode
(
output_ids
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
self
.
assertListEqual
(
expected_output_string
,
output_strings
)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with
tf
.
device
(
":/CPU:0"
):
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
# seed set -> deterministic sampling sequence -> deterministic generation
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla
=
xla_generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
42
,
0
])
output_ids_xla
=
xla_generate
(
input_ids
,
do_sample
=
True
,
seed
=
[
42
,
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment