chenpangpang / transformers / Commits / 623346ab

Unverified commit 623346ab, authored Jan 31, 2023 by Joao Gante, committed via GitHub on Jan 31, 2023
Template for framework-agnostic tests (#21348)
parent 5451f889
Showing 4 changed files with 68 additions and 41 deletions (+68, -41)
tests/generation/test_framework_agnostic.py   +41  -0
tests/generation/test_tf_utils.py             +13  -18
tests/generation/test_utils.py                +11  -21
utils/tests_fetcher.py                        +3   -2
tests/generation/test_framework_agnostic.py (new file, mode 100644)
"""
Framework agnostic tests for generate()-related methods.
"""
import
numpy
as
np
from
transformers
import
AutoTokenizer
class
GenerationIntegrationTestsMixin
:
# To be populated by the child classes
framework_dependent_parameters
=
{
"AutoModelForSeq2SeqLM"
:
None
,
"create_tensor_fn"
:
None
,
"return_tensors"
:
None
,
}
def
test_validate_generation_inputs
(
self
):
model_cls
=
self
.
framework_dependent_parameters
[
"AutoModelForSeq2SeqLM"
]
return_tensors
=
self
.
framework_dependent_parameters
[
"return_tensors"
]
create_tensor_fn
=
self
.
framework_dependent_parameters
[
"create_tensor_fn"
]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
)
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
)
encoder_input_str
=
"Hello world"
input_ids
=
tokenizer
(
encoder_input_str
,
return_tensors
=
return_tensors
).
input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with
self
.
assertRaisesRegex
(
ValueError
,
"do_samples"
):
model
.
generate
(
input_ids
,
do_samples
=
True
)
# arbitrary arguments that will not be used anywhere are also not accepted
with
self
.
assertRaisesRegex
(
ValueError
,
"foo"
):
fake_model_kwargs
=
{
"foo"
:
"bar"
}
model
.
generate
(
input_ids
,
**
fake_model_kwargs
)
# however, valid model_kwargs are accepted
valid_model_kwargs
=
{
"attention_mask"
:
create_tensor_fn
(
np
.
zeros_like
(
input_ids
))}
model
.
generate
(
input_ids
,
**
valid_model_kwargs
)
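The mixin is meant as a template: any backend can reuse the test body by filling in framework_dependent_parameters. As a purely hypothetical illustration (not part of this commit), a Flax test module could plug in the same way; the choices FlaxAutoModelForSeq2SeqLM, jnp.asarray, and return_tensors="np" below are assumptions about how such a port might look, not code from the diff.

import unittest

from transformers import is_flax_available
from transformers.testing_utils import require_flax

from .test_framework_agnostic import GenerationIntegrationTestsMixin

if is_flax_available():
    import jax.numpy as jnp

    from transformers import FlaxAutoModelForSeq2SeqLM


@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
    # gated for the same reason as the TF and PyTorch classes changed below:
    # the class body is evaluated at import time, even when Flax is absent
    if is_flax_available():
        framework_dependent_parameters = {
            "AutoModelForSeq2SeqLM": FlaxAutoModelForSeq2SeqLM,
            "create_tensor_fn": jnp.asarray,
            "return_tensors": "np",
        }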
tests/generation/test_tf_utils.py
@@ -19,11 +19,13 @@ import unittest
 from transformers import is_tf_available
 from transformers.testing_utils import require_tf, slow
+
+from .test_framework_agnostic import GenerationIntegrationTestsMixin


 if is_tf_available():
     import tensorflow as tf

-    from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
+    from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering


 @require_tf
@@ -124,7 +126,16 @@ class UtilsFunctionsTest(unittest.TestCase):

 @require_tf
-class TFGenerationIntegrationTests(unittest.TestCase):
+class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
+    # setting framework_dependent_parameters needs to be gated, just like its contents' imports
+    if is_tf_available():
+        framework_dependent_parameters = {
+            "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
+            "create_tensor_fn": tf.convert_to_tensor,
+            "return_tensors": "tf",
+        }
+
     @slow
     def test_generate_tf_function_export(self):
         test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -165,19 +176,3 @@ class TFGenerationIntegrationTests(unittest.TestCase):
         tf_func_outputs = serving_func(**inputs)["sequences"]
         tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
         tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
-
-    def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-        model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
-
-        encoder_input_str = "Hello world"
-        input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids
-
-        # typos are quickly detected (the correct argument is `do_sample`)
-        with self.assertRaisesRegex(ValueError, "do_samples"):
-            model.generate(input_ids, do_samples=True)
-
-        # arbitrary arguments that will not be used anywhere are also not accepted
-        with self.assertRaisesRegex(ValueError, "foo"):
-            fake_model_kwargs = {"foo": "bar"}
-            model.generate(input_ids, **fake_model_kwargs)
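The "needs to be gated" comment in the hunk above exists because class bodies run at import time. A minimal sketch of the pattern, assuming only the standard transformers helpers is_tf_available and require_tf (the class name here is illustrative):

import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_tf

if is_tf_available():
    import tensorflow as tf


@require_tf
class ExampleGatedTests(unittest.TestCase):
    # without this inner guard, evaluating `tf.convert_to_tensor` in the class body
    # would raise NameError on machines without TensorFlow, before any test is skipped
    if is_tf_available():
        framework_dependent_parameters = {"create_tensor_fn": tf.convert_to_tensor}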
tests/generation/test_utils.py
@@ -23,6 +23,7 @@ from transformers import is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device

 from ..test_modeling_common import floats_tensor, ids_tensor
+from .test_framework_agnostic import GenerationIntegrationTestsMixin


 if is_torch_available():
@@ -1790,7 +1791,16 @@ class UtilsFunctionsTest(unittest.TestCase):

 @require_torch
-class GenerationIntegrationTests(unittest.TestCase):
+class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
+    # setting framework_dependent_parameters needs to be gated, just like its contents' imports
+    if is_torch_available():
+        framework_dependent_parameters = {
+            "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
+            "create_tensor_fn": torch.tensor,
+            "return_tensors": "pt",
+        }
+
     @slow
     def test_diverse_beam_search(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
@@ -3022,26 +3032,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
         self.assertTrue(max_score_diff < 1e-5)
-
-    def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
-        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
-
-        encoder_input_str = "Hello world"
-        input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
-
-        # typos are quickly detected (the correct argument is `do_sample`)
-        with self.assertRaisesRegex(ValueError, "do_samples"):
-            model.generate(input_ids, do_samples=True)
-
-        # arbitrary arguments that will not be used anywhere are also not accepted
-        with self.assertRaisesRegex(ValueError, "foo"):
-            fake_model_kwargs = {"foo": "bar"}
-            model.generate(input_ids, **fake_model_kwargs)
-
-        # However, valid model_kwargs are accepted
-        valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
-        model.generate(input_ids, **valid_model_kwargs)

     def test_eos_token_id_int_and_list_greedy_search(self):
         generation_kwargs = {
             "do_sample": False,
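For context on why the shared mixin builds its attention mask with np.zeros_like plus the per-framework create_tensor_fn, rather than calling torch.zeros_like or tf.zeros_like directly: the two routes produce the same zero mask on each backend. An illustrative check, assuming both PyTorch and TensorFlow are installed (this snippet is not part of the commit):

import numpy as np
import tensorflow as tf
import torch

input_ids_pt = torch.tensor([[5, 7, 9]])
input_ids_tf = tf.convert_to_tensor([[5, 7, 9]])

# PyTorch path: numpy intermediary matches torch.zeros_like
assert torch.equal(torch.tensor(np.zeros_like(input_ids_pt)), torch.zeros_like(input_ids_pt))

# TensorFlow path: numpy intermediary matches tf.zeros_like
assert bool(tf.reduce_all(tf.convert_to_tensor(np.zeros_like(input_ids_tf)) == tf.zeros_like(input_ids_tf)))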
utils/tests_fetcher.py
@@ -466,12 +466,13 @@ def module_to_test_file(module_fname):
 # This list contains the list of test files we expect never to be launched from a change in a module/util. Those are
 # launched separately.
 EXPECTED_TEST_FILES_NEVER_TOUCHED = [
+    "tests/utils/test_doc_samples.py",  # Doc tests
+    "tests/generation/test_framework_agnostic.py",  # Mixins inherited by actual test classes
+    "tests/mixed_int8/test_mixed_int8.py",  # Mixed-int8 bitsandbytes test
     "tests/pipelines/test_pipelines_common.py",  # Actually checked by the pipeline based file
     "tests/sagemaker/test_single_node_gpu.py",  # SageMaker test
     "tests/sagemaker/test_multi_node_model_parallel.py",  # SageMaker test
     "tests/sagemaker/test_multi_node_data_parallel.py",  # SageMaker test
-    "tests/mixed_int8/test_mixed_int8.py",  # Mixed-int8 bitsandbytes test
-    "tests/utils/test_doc_samples.py",  # Doc tests
 ]
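The new entry for tests/generation/test_framework_agnostic.py reflects how mixins interact with test discovery: a class that does not inherit unittest.TestCase is not collected on its own, so the file never needs to be launched directly; its methods only run through the concrete classes that mix it in. A small self-contained sketch of that behavior, with hypothetical class names:

import unittest


class _AgnosticMixin:
    # not a TestCase, so the unittest runner never collects this class by itself
    def test_something(self):
        self.assertTrue(True)


class ConcreteTests(unittest.TestCase, _AgnosticMixin):
    pass


if __name__ == "__main__":
    unittest.main()  # runs ConcreteTests.test_something; _AgnosticMixin alone is skipped by discovery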