Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9a24e0cf
"docs/source/vscode:/vscode.git/clone" did not exist on "9edf37583411f892cea9ae7d98156c85d7c087b1"
Commit
9a24e0cf
authored
Dec 11, 2019
by
Morgan Funtowicz
Browse files
Refactored qa pipeline argument handling + unittests
parent
63e36007
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
33 deletions
+87
-33
transformers/pipelines.py
transformers/pipelines.py
+55
-32
transformers/tests/pipelines_test.py
transformers/tests/pipelines_test.py
+32
-1
No files found.
transformers/pipelines.py
View file @
9a24e0cf
...
...
@@ -98,14 +98,11 @@ class TextClassificationPipeline(Pipeline):
class
QuestionAnsweringPipeline
(
Pipeline
):
"""
Question Answering pipeling involving Tokenization and Inference.
TODO:
- top-k answers
- return start/end chars
- return score
"""
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
super
().
__init__
(
model
,
tokenizer
)
@
classmethod
def
from_config
(
cls
,
model
,
tokenizer
:
PreTrainedTokenizer
,
**
kwargs
):
pass
@
staticmethod
def
create_sample
(
question
:
Union
[
str
,
List
[
str
]],
context
:
Union
[
str
,
List
[
str
]])
->
Union
[
SquadExample
,
List
[
SquadExample
]]:
...
...
@@ -116,6 +113,55 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
@
staticmethod
def
handle_args
(
*
inputs
,
**
kwargs
)
->
List
[
SquadExample
]:
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if
inputs
is
not
None
and
len
(
inputs
)
>
1
:
kwargs
[
'X'
]
=
inputs
# Generic compatibility with sklearn and Keras
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
for
i
,
item
in
enumerate
(
data
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
continue
else
:
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
inputs
=
data
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
if
isinstance
(
kwargs
[
'question'
],
str
):
kwargs
[
'question'
]
=
[
kwargs
[
'question'
]]
if
isinstance
(
kwargs
[
'context'
],
str
):
kwargs
[
'context'
]
=
[
kwargs
[
'context'
]]
inputs
=
[
QuestionAnsweringPipeline
.
create_sample
(
q
,
c
)
for
q
,
c
in
zip
(
kwargs
[
'question'
],
kwargs
[
'context'
])]
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
inputs
,
list
):
inputs
=
[
inputs
]
return
inputs
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
super
().
__init__
(
model
,
tokenizer
)
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
args
=
[
'input_ids'
,
'attention_mask'
]
model_type
=
type
(
self
.
model
).
__name__
.
lower
()
...
...
@@ -131,10 +177,6 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
return
{
k
:
[
feature
.
__dict__
[
k
]
for
feature
in
features
]
for
k
in
args
}
@
classmethod
def
from_config
(
cls
,
model
,
tokenizer
:
PreTrainedTokenizer
,
**
kwargs
):
pass
def
__call__
(
self
,
*
texts
,
**
kwargs
):
# Set defaults values
kwargs
.
setdefault
(
'topk'
,
1
)
...
...
@@ -149,29 +191,10 @@ class QuestionAnsweringPipeline(Pipeline):
if
kwargs
[
'max_answer_len'
]
<
1
:
raise
ValueError
(
'max_answer_len parameter should be >= 1 (got {})'
.
format
(
kwargs
[
'max_answer_len'
]))
# Position args
if
texts
is
not
None
and
len
(
texts
)
>
1
:
(
texts
,
)
=
texts
# Generic compatibility with sklearn and Keras
elif
'X'
in
kwargs
and
not
texts
:
texts
=
kwargs
.
pop
(
'X'
)
# Batched data
elif
'data'
in
kwargs
:
texts
=
kwargs
.
pop
(
'data'
)
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
texts
=
QuestionAnsweringPipeline
.
create_sample
(
kwargs
[
'question'
],
kwargs
[
'context'
])
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
texts
,
list
):
texts
=
[
texts
]
examples
=
QuestionAnsweringPipeline
.
handle_args
(
texts
,
**
kwargs
)
# Convert inputs to features
features
=
squad_convert_examples_to_features
(
t
ex
t
s
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
features
=
squad_convert_examples_to_features
(
ex
ample
s
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
fw_args
=
self
.
inputs_for_model
(
features
)
if
is_tf_available
():
...
...
@@ -188,7 +211,7 @@ class QuestionAnsweringPipeline(Pipeline):
start
,
end
=
start
.
cpu
().
numpy
(),
end
.
cpu
().
numpy
()
answers
=
[]
for
(
example
,
feature
,
start_
,
end_
)
in
zip
(
t
ex
t
s
,
features
,
start
,
end
):
for
(
example
,
feature
,
start_
,
end_
)
in
zip
(
ex
ample
s
,
features
,
start
,
end
):
# Normalize logits and spans to retrieve the answer
start_
=
np
.
exp
(
start_
)
/
np
.
sum
(
np
.
exp
(
start_
))
end_
=
np
.
exp
(
end_
)
/
np
.
sum
(
np
.
exp
(
end_
))
...
...
transformers/tests/pipelines_test.py
View file @
9a24e0cf
...
...
@@ -40,7 +40,38 @@ class QuestionAnsweringPipelineTest(unittest.TestCase):
# Batch case with topk = 2
a
=
nlp
(
question
=
[
'What is the name of the company I
\'
m working for ?'
,
'Where is the company based ?'
],
context
=
[
'I
\'
m working for Huggingface.'
,
'The company is based in New York and Paris'
],
topk
=
2
)
context
=
[
'Where is the company based ?'
,
'The company is based in New York and Paris'
],
topk
=
2
)
self
.
check_answer_structure
(
a
,
2
,
2
)
# check for data keyword
a
=
nlp
(
data
=
nlp
.
create_sample
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
))
self
.
check_answer_structure
(
a
,
1
,
1
)
a
=
nlp
(
data
=
nlp
.
create_sample
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
),
topk
=
2
)
self
.
check_answer_structure
(
a
,
1
,
2
)
a
=
nlp
(
data
=
[
nlp
.
create_sample
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
),
nlp
.
create_sample
(
question
=
'I
\'
m working for Huggingface.'
,
context
=
'The company is based in New York and Paris'
),
])
self
.
check_answer_structure
(
a
,
2
,
1
)
a
=
nlp
(
data
=
[
{
'question'
:
'What is the name of the company I
\'
m working for ?'
,
'context'
:
'I
\'
m working for Huggingface.'
},
{
'question'
:
'Where is the company based ?'
,
'context'
:
'The company is based in New York and Paris'
},
])
self
.
check_answer_structure
(
a
,
2
,
1
)
# X keywords
a
=
nlp
(
X
=
nlp
.
create_sample
(
question
=
'Where is the company based ?'
,
context
=
'The company is based in New York and Paris'
))
self
.
check_answer_structure
(
a
,
1
,
1
)
a
=
nlp
(
X
=
[
{
'question'
:
'What is the name of the company I
\'
m working for ?'
,
'context'
:
'I
\'
m working for Huggingface.'
},
{
'question'
:
'Where is the company based ?'
,
'context'
:
'The company is based in New York and Paris'
},
],
topk
=
2
)
self
.
check_answer_structure
(
a
,
2
,
2
)
@
patch
(
'transformers.pipelines.is_torch_available'
,
return_value
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment