chenpangpang / transformers · Commit 6d90d76f (unverified)

TF: rework XLA generate tests (#16866)

Authored Apr 22, 2022 by Joao Gante; committed by GitHub on Apr 22, 2022. Parent: 3b1bbefc

Showing 2 changed files with 76 additions and 54 deletions:

tests/gpt2/test_modeling_tf_gpt2.py  +47 −31
tests/t5/test_modeling_tf_t5.py      +29 −23
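The substantive change is the same in both files: each XLA test compiles model.generate with tf.function(..., jit_compile=True) and checks the result against eager generation. A minimal sketch of that pattern, using the gpt2 checkpoint that appears in the diff (the prompt and the final assertion are illustrative, not taken from the commit):

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The dog", return_tensors="tf").input_ids

# Eager generation (greedy decoding, so the output is deterministic)
eager_ids = model.generate(input_ids, do_sample=False)

# XLA-compiled generation: tf.function traces generate() once, and
# jit_compile=True lowers the traced graph to XLA.
xla_generate = tf.function(model.generate, jit_compile=True)
xla_ids = xla_generate(input_ids, do_sample=False)

# With greedy decoding, both modes should produce identical token ids.
assert eager_ids.numpy().tolist() == xla_ids.numpy().tolist()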
tests/gpt2/test_modeling_tf_gpt2.py
@@ -16,7 +16,7 @@
 import unittest

 from transformers import GPT2Config, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import get_gpu_count, require_tf, slow

 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
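The newly imported get_gpu_count is used to skip the XLA integration tests on CPU-only machines. A short sketch of how the decorators compose (the test class and body here are illustrative, not part of the commit):

import unittest

from transformers.testing_utils import get_gpu_count, require_tf, slow


@require_tf
class XLASkipExample(unittest.TestCase):
    # Hypothetical test, shown only to illustrate the decorator stack used below.
    @slow
    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
    def test_runs_only_with_a_gpu(self):
        self.assertGreaterEqual(get_gpu_count(), 1)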
@@ -294,7 +294,7 @@ class TFGPT2ModelTester:
         result = model(inputs)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

-    def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
+    def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
         config.eos_token_id = None
         config.max_length = 10
         model = TFGPT2LMHeadModel(config=config)
@@ -408,9 +408,9 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)

-    def test_gpt2_xla_generate(self):
+    def test_gpt2_xla_generate_fast(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
+        self.model_tester.create_and_check_gpt2_xla_generate_fast(*config_and_inputs)

     def test_gpt2_double_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
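The diff shows only the first lines of the renamed _fast check. A sketch of what such a fast check plausibly does, assuming (an assumption, not the verbatim test body) that it compares eager and XLA greedy generation on a small randomly initialized model:

import tensorflow as tf
from transformers import GPT2Config, TFGPT2LMHeadModel


def check_xla_generate_fast(config, input_ids):
    # The first three lines mirror the diff; the comparison below is assumed.
    config.eos_token_id = None  # no early stopping: generate until max_length
    config.max_length = 10
    model = TFGPT2LMHeadModel(config=config)

    eager_ids = model.generate(input_ids, do_sample=False)

    xla_generate = tf.function(model.generate, jit_compile=True)
    xla_ids = xla_generate(input_ids, do_sample=False)

    assert eager_ids.numpy().tolist() == xla_ids.numpy().tolist()


# Tiny config so the check runs in seconds, even on CPU.
check_xla_generate_fast(
    GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2),
    tf.constant([[1, 2, 3]], dtype=tf.int32),
)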
@@ -536,41 +536,57 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         self.assertListEqual(output_strings, expected_output_string)

     @slow
-    def test_lm_generate_gpt2(self):
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")  # TODO: remove the skip when the XLA CPU softmax issue gets sorted
+    def test_lm_generate_gpt2_greedy_xla(self):
+        # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
+        # the underlying problem)
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

-        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        # fmt: off
-        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
-        # fmt: on
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"

-    @slow
-    def test_lm_generate_gpt2_xla_greedy(self):
-        """This test gives the exact same results as the non-xla test above"""
-        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        sentences = ["The dog"]
+        expected_output_strings = [
+            "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+        ]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

-        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
-        # fmt: off
-        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
-        # fmt: on
-        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = model.generate(input_ids, do_sample=False)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)

-        output_ids = xla_generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=False)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)

     @slow
-    def test_lm_generate_gpt2_xla_sample(self):
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")  # TODO: remove the skip when the XLA CPU softmax issue gets sorted
+    def test_lm_generate_gpt2_sample_xla(self):
+        # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
+        # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
+        # and that we can seed both versions.
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

-        # fmt: off
-        expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
-        # fmt: on
-        xla_generate = tf.function(model.generate, jit_compile=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"

-        output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+        sentence = ["The dog"]
+        expected_output_string = [
+            "The dog must be well educated to do anything. If anything, this must be her best friend"
+        ]
+        expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+
+        output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_string)
+
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_string_xla)
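The reworked GPT-2 tests replace hand-written token ids with tokenizer round-trips, which requires the padding setup seen above: GPT-2 ships without a pad token, and decoder-only models must be padded on the left so the prompt ends exactly where generation begins. A small illustration of what that setup produces (the prompts are illustrative):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # reuse eos as pad, as in the tests
tokenizer.padding_side = "left"

batch = tokenizer(
    ["The dog", "A considerably longer prompt than the first one"],
    return_tensors="tf",
    padding=True,
)
# input_ids is left-padded with the eos id; attention_mask zeroes out the
# padding so generate() attends only to the real prompt tokens.
print(batch.input_ids)
print(batch.attention_mask)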
tests/t5/test_modeling_tf_t5.py
@@ -16,7 +16,7 @@
 import unittest

 from transformers import T5Config, is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
 from transformers.utils import cached_property

 from ..test_configuration_common import ConfigTester
@@ -227,7 +227,7 @@ class TFT5ModelTester:
         # test that outputs are equal for slice
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

-    def create_and_check_t5_xla_generate(self, config, input_ids, *args):
+    def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
         config.eos_token_id = None
         config.max_length = 10
         config.do_sample = False
@@ -297,9 +297,9 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)

-    def test_t5_model_xla_generate(self):
+    def test_t5_model_xla_generate_fast(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
+        self.model_tester.create_and_check_t5_xla_generate_fast(*config_and_inputs)

     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -481,12 +481,18 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
 @require_tokenizers
 class TFT5GenerationIntegrationTests(unittest.TestCase):
     @slow
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")  # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_greedy_xla_generate_simple(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
         tokenizer = T5Tokenizer.from_pretrained("t5-small")

-        sentence = "Translate English to German: Today is a beautiful day."
-        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+        # two examples with different lengths to confirm that attention masks are operational in XLA
+        sentences = [
+            "Translate English to German: Today is a beautiful day.",
+            "Translate English to German: I have four cats, three dogs, two birds, and a horse.",
+        ]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

         xla_generate = tf.function(model.generate, jit_compile=True)
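The comment in the new test states the intent: batching two prompts of different lengths forces padding, so the test only passes if attention masks survive XLA compilation. A sketch making the mask explicit (the test itself passes only input_ids and lets generate infer the mask from the pad tokens):

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer(
    [
        "Translate English to German: Today is a beautiful day.",
        "Translate English to German: I have four cats, three dogs, two birds, and a horse.",
    ],
    return_tensors="tf",
    padding=True,  # the shorter sentence is padded to the longer one's length
)

xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(inputs.input_ids, attention_mask=inputs.attention_mask)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))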
@@ -496,7 +502,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
-        expected_output_string = ["Heute ist ein schöner Tag."]
+        expected_output_string = [
+            "Heute ist ein schöner Tag.",
+            "Ich habe vier Katzen, drei Hunde, zwei Vögel und ein Pferd.",
+        ]

         self.assertListEqual(expected_output_string, output_strings)
         self.assertListEqual(expected_output_string, output_strings_xla)
@@ -525,31 +534,28 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(expected_output_string, output_strings)

     @slow
+    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")  # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_sample_xla_generate_simple(self):
+        # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
+        # output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
+        # results are sensible and that we can seed both versions.
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
         tokenizer = T5Tokenizer.from_pretrained("t5-small")

-        sentence = "Translate English to German: Today is a beautiful day."
+        sentence = "Translate English to German: I have two bananas"
         input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
-
-        # XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
-        # divergences in generate -- especially with sampling.
-        expected_output_string = ["Heute ist ein schöner Tag."]
-        expected_output_string_xla = ["Heute ist ein schöne Tage."]
-        # However, notice that the first tokens are the same, for the same seed
-        assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
+        expected_output_string = ["Ich habe 2 Bananen"]
+        expected_output_string_xla = ["Ich habe 2 Bananen"]

-        # forces the generation to happen on CPU, to avoid GPU-related quirks
-        with tf.device(":/CPU:0"):
-            # seed set -> deterministic sampling sequence -> deterministic generation
-            output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
+        # seed set -> deterministic sampling sequence -> deterministic generation
+        output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(expected_output_string, output_strings)

-        # forces the generation to happen on CPU, to avoid GPU-related quirks
-        with tf.device(":/CPU:0"):
-            # seed set -> deterministic sampling sequence -> deterministic generation
-            xla_generate = tf.function(model.generate, jit_compile=True)
-            output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        # seed set -> deterministic sampling sequence -> deterministic generation
+        output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
         self.assertListEqual(expected_output_string_xla, output_strings_xla)
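The seed=[42, 0] argument passed to generate above is a two-integer stateless-RNG seed, so sampling is reproducible within a given execution mode even though, as the NOTE comments say, eager and XLA runs may drift apart numerically. A sketch of the reproducibility guarantee the test leans on (checkpoint and sentence taken from the diff; the assertions are illustrative):

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "Translate English to German: I have two bananas", return_tensors="tf"
).input_ids

# Same stateless seed -> identical sampled tokens across eager runs.
first = model.generate(input_ids, do_sample=True, seed=[42, 0])
second = model.generate(input_ids, do_sample=True, seed=[42, 0])
assert first.numpy().tolist() == second.numpy().tolist()

# The XLA-compiled run is also deterministic for a fixed seed, but is only
# expected to approximate the eager output, not match it token for token.
xla_generate = tf.function(model.generate, jit_compile=True)
xla_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
print(tokenizer.batch_decode(xla_ids, skip_special_tokens=True))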