Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0c88c856
Commit
0c88c856
authored
Dec 18, 2019
by
Morgan Funtowicz
Browse files
Unnest QuestionAnsweringArgumentHandler
parent
e347725d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
2 deletions
+60
-2
transformers/pipelines.py
transformers/pipelines.py
+60
-2
No files found.
transformers/pipelines.py
View file @
0c88c856
...
@@ -333,6 +333,63 @@ class NerPipeline(Pipeline):
...
@@ -333,6 +333,63 @@ class NerPipeline(Pipeline):
return
answers
return
answers
class
QuestionAnsweringArgumentHandler
(
ArgumentHandler
):
"""
QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
to internal SquadExample / SquadFeature structures.
QuestionAnsweringArgumentHandler manages all the possible to create SquadExample from the command-line supplied
arguments.
"""
def
__call__
(
self
,
*
args
,
**
kwargs
):
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if
args
is
not
None
and
len
(
args
)
>
0
:
if
len
(
args
)
==
1
:
kwargs
[
'X'
]
=
args
[
0
]
else
:
kwargs
[
'X'
]
=
list
(
args
)
# Generic compatibility with sklearn and Keras
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
for
i
,
item
in
enumerate
(
data
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
continue
else
:
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
inputs
=
data
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
if
isinstance
(
kwargs
[
'question'
],
str
):
kwargs
[
'question'
]
=
[
kwargs
[
'question'
]]
if
isinstance
(
kwargs
[
'context'
],
str
):
kwargs
[
'context'
]
=
[
kwargs
[
'context'
]]
inputs
=
[
QuestionAnsweringPipeline
.
create_sample
(
q
,
c
)
for
q
,
c
in
zip
(
kwargs
[
'question'
],
kwargs
[
'context'
])]
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
inputs
,
list
):
inputs
=
[
inputs
]
return
inputs
class
QuestionAnsweringPipeline
(
Pipeline
):
class
QuestionAnsweringPipeline
(
Pipeline
):
"""
"""
Question Answering pipeline using ModelForQuestionAnswering head.
Question Answering pipeline using ModelForQuestionAnswering head.
...
@@ -403,8 +460,9 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -403,8 +460,9 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
else
:
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
],
device
:
int
=
-
1
,
**
kwargs
):
super
().
__init__
(
model
,
tokenizer
,
args_parser
=
QuestionAnsweringPipeline
.
QuestionAnsweringArgumentHandler
())
super
().
__init__
(
model
,
tokenizer
,
args_parser
=
QuestionAnsweringArgumentHandler
(),
device
=
device
,
**
kwargs
)
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment