Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aae74065
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4965aee064f7aaf380269e88b2dd650867fb2199"
Commit
aae74065
authored
Dec 09, 2019
by
Morgan Funtowicz
Browse files
Added QuestionAnsweringPipeline unit tests.
parent
a7d3794a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
0 deletions
+83
-0
transformers/tests/pipelines_test.py
transformers/tests/pipelines_test.py
+83
-0
No files found.
transformers/tests/pipelines_test.py
0 → 100644
View file @
aae74065
import
unittest
from
unittest.mock
import
patch
QA_FINETUNED_MODELS
=
{
'bert-large-uncased-whole-word-masking-finetuned-squad'
,
'bert-large-cased-whole-word-masking-finetuned-squad'
,
'distilbert-base-uncased-distilled-squad'
,
}
class
QuestionAnsweringPipelineTest
(
unittest
.
TestCase
):
def
check_answer_structure
(
self
,
answer
,
batch
,
topk
):
self
.
assertIsInstance
(
answer
,
list
)
self
.
assertEqual
(
len
(
answer
),
batch
)
self
.
assertIsInstance
(
answer
[
0
],
list
)
self
.
assertEqual
(
len
(
answer
[
0
]),
topk
)
self
.
assertIsInstance
(
answer
[
0
][
0
],
dict
)
for
item
in
answer
[
0
]:
self
.
assertTrue
(
'start'
in
item
)
self
.
assertTrue
(
'end'
in
item
)
self
.
assertTrue
(
'score'
in
item
)
self
.
assertTrue
(
'answer'
in
item
)
def
question_answering_pipeline
(
self
,
nlp
):
# Simple case with topk = 1, no batching
a
=
nlp
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
)
self
.
check_answer_structure
(
a
,
1
,
1
)
# Simple case with topk = 2, no batching
a
=
nlp
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
,
topk
=
2
)
self
.
check_answer_structure
(
a
,
1
,
2
)
# Batch case with topk = 1
a
=
nlp
(
question
=
[
'What is the name of the company I
\'
m working for ?'
,
'Where is the company based ?'
],
context
=
[
'I
\'
m working for Huggingface.'
,
'The company is based in New York and Paris'
])
self
.
check_answer_structure
(
a
,
2
,
1
)
# Batch case with topk = 2
a
=
nlp
(
question
=
[
'What is the name of the company I
\'
m working for ?'
,
'Where is the company based ?'
],
context
=
[
'I
\'
m working for Huggingface.'
,
'The company is based in New York and Paris'
],
topk
=
2
)
self
.
check_answer_structure
(
a
,
2
,
2
)
@
patch
(
'transformers.pipelines.is_torch_available'
,
return_value
=
False
)
def
test_tf_models
(
self
,
is_torch_available
):
from
transformers
import
pipeline
for
model
in
QA_FINETUNED_MODELS
:
self
.
question_answering_pipeline
(
pipeline
(
'question-answering'
,
model
))
@
patch
(
'transformers.pipelines.is_tf_available'
,
return_value
=
False
)
@
patch
(
'transformers.tokenization_utils.is_tf_available'
,
return_value
=
False
)
def
test_torch_models
(
self
,
is_tf_available
,
_
):
from
transformers
import
pipeline
for
model
in
QA_FINETUNED_MODELS
:
self
.
question_answering_pipeline
(
pipeline
(
'question-answering'
,
model
))
class
AutoPipelineTest
(
unittest
.
TestCase
):
@
patch
(
'transformers.pipelines.is_torch_available'
,
return_value
=
False
)
def
test_tf_qa
(
self
,
is_torch_available
):
from
transformers
import
pipeline
from
transformers.pipelines
import
QuestionAnsweringPipeline
from
transformers.modeling_tf_utils
import
TFPreTrainedModel
for
model
in
QA_FINETUNED_MODELS
:
nlp
=
pipeline
(
'question-answering'
,
model
)
self
.
assertIsInstance
(
nlp
,
QuestionAnsweringPipeline
)
self
.
assertIsInstance
(
nlp
.
model
,
TFPreTrainedModel
)
@
patch
(
'transformers.pipelines.is_tf_available'
,
return_value
=
False
)
def
test_torch_qa
(
self
,
is_tf_available
):
from
transformers
import
pipeline
from
transformers.pipelines
import
QuestionAnsweringPipeline
from
transformers.modeling_utils
import
PreTrainedModel
for
model
in
QA_FINETUNED_MODELS
:
nlp
=
pipeline
(
'question-answering'
,
model
)
self
.
assertIsInstance
(
nlp
,
QuestionAnsweringPipeline
)
self
.
assertIsInstance
(
nlp
.
model
,
PreTrainedModel
)
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment