Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9a24e0cf
Commit
9a24e0cf
authored
Dec 11, 2019
by
Morgan Funtowicz
Browse files
Refactored qa pipeline argument handling + unittests
parent
63e36007
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
33 deletions
+87
-33
transformers/pipelines.py
transformers/pipelines.py
+55
-32
transformers/tests/pipelines_test.py
transformers/tests/pipelines_test.py
+32
-1
No files found.
transformers/pipelines.py
View file @
9a24e0cf
...
@@ -98,14 +98,11 @@ class TextClassificationPipeline(Pipeline):
...
@@ -98,14 +98,11 @@ class TextClassificationPipeline(Pipeline):
class
QuestionAnsweringPipeline
(
Pipeline
):
class
QuestionAnsweringPipeline
(
Pipeline
):
"""
"""
Question Answering pipeling involving Tokenization and Inference.
Question Answering pipeling involving Tokenization and Inference.
TODO:
- top-k answers
- return start/end chars
- return score
"""
"""
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
@
classmethod
super
().
__init__
(
model
,
tokenizer
)
def
from_config
(
cls
,
model
,
tokenizer
:
PreTrainedTokenizer
,
**
kwargs
):
pass
@
staticmethod
@
staticmethod
def
create_sample
(
question
:
Union
[
str
,
List
[
str
]],
context
:
Union
[
str
,
List
[
str
]])
->
Union
[
SquadExample
,
List
[
SquadExample
]]:
def
create_sample
(
question
:
Union
[
str
,
List
[
str
]],
context
:
Union
[
str
,
List
[
str
]])
->
Union
[
SquadExample
,
List
[
SquadExample
]]:
...
@@ -116,6 +113,55 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -116,6 +113,55 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
else
:
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
@
staticmethod
def
handle_args
(
*
inputs
,
**
kwargs
)
->
List
[
SquadExample
]:
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if
inputs
is
not
None
and
len
(
inputs
)
>
1
:
kwargs
[
'X'
]
=
inputs
# Generic compatibility with sklearn and Keras
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
for
i
,
item
in
enumerate
(
data
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
continue
else
:
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
inputs
=
data
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
if
isinstance
(
kwargs
[
'question'
],
str
):
kwargs
[
'question'
]
=
[
kwargs
[
'question'
]]
if
isinstance
(
kwargs
[
'context'
],
str
):
kwargs
[
'context'
]
=
[
kwargs
[
'context'
]]
inputs
=
[
QuestionAnsweringPipeline
.
create_sample
(
q
,
c
)
for
q
,
c
in
zip
(
kwargs
[
'question'
],
kwargs
[
'context'
])]
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
inputs
,
list
):
inputs
=
[
inputs
]
return
inputs
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
super
().
__init__
(
model
,
tokenizer
)
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
args
=
[
'input_ids'
,
'attention_mask'
]
args
=
[
'input_ids'
,
'attention_mask'
]
model_type
=
type
(
self
.
model
).
__name__
.
lower
()
model_type
=
type
(
self
.
model
).
__name__
.
lower
()
...
@@ -131,10 +177,6 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -131,10 +177,6 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
else
:
return
{
k
:
[
feature
.
__dict__
[
k
]
for
feature
in
features
]
for
k
in
args
}
return
{
k
:
[
feature
.
__dict__
[
k
]
for
feature
in
features
]
for
k
in
args
}
@
classmethod
def
from_config
(
cls
,
model
,
tokenizer
:
PreTrainedTokenizer
,
**
kwargs
):
pass
def
__call__
(
self
,
*
texts
,
**
kwargs
):
def
__call__
(
self
,
*
texts
,
**
kwargs
):
# Set defaults values
# Set defaults values
kwargs
.
setdefault
(
'topk'
,
1
)
kwargs
.
setdefault
(
'topk'
,
1
)
...
@@ -149,29 +191,10 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -149,29 +191,10 @@ class QuestionAnsweringPipeline(Pipeline):
if
kwargs
[
'max_answer_len'
]
<
1
:
if
kwargs
[
'max_answer_len'
]
<
1
:
raise
ValueError
(
'max_answer_len parameter should be >= 1 (got {})'
.
format
(
kwargs
[
'max_answer_len'
]))
raise
ValueError
(
'max_answer_len parameter should be >= 1 (got {})'
.
format
(
kwargs
[
'max_answer_len'
]))
# Position args
examples
=
QuestionAnsweringPipeline
.
handle_args
(
texts
,
**
kwargs
)
if
texts
is
not
None
and
len
(
texts
)
>
1
:
(
texts
,
)
=
texts
# Generic compatibility with sklearn and Keras
elif
'X'
in
kwargs
and
not
texts
:
texts
=
kwargs
.
pop
(
'X'
)
# Batched data
elif
'data'
in
kwargs
:
texts
=
kwargs
.
pop
(
'data'
)
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
texts
=
QuestionAnsweringPipeline
.
create_sample
(
kwargs
[
'question'
],
kwargs
[
'context'
])
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
texts
,
list
):
texts
=
[
texts
]
# Convert inputs to features
# Convert inputs to features
features
=
squad_convert_examples_to_features
(
t
ex
t
s
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
features
=
squad_convert_examples_to_features
(
ex
ample
s
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
fw_args
=
self
.
inputs_for_model
(
features
)
fw_args
=
self
.
inputs_for_model
(
features
)
if
is_tf_available
():
if
is_tf_available
():
...
@@ -188,7 +211,7 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -188,7 +211,7 @@ class QuestionAnsweringPipeline(Pipeline):
start
,
end
=
start
.
cpu
().
numpy
(),
end
.
cpu
().
numpy
()
start
,
end
=
start
.
cpu
().
numpy
(),
end
.
cpu
().
numpy
()
answers
=
[]
answers
=
[]
for
(
example
,
feature
,
start_
,
end_
)
in
zip
(
t
ex
t
s
,
features
,
start
,
end
):
for
(
example
,
feature
,
start_
,
end_
)
in
zip
(
ex
ample
s
,
features
,
start
,
end
):
# Normalize logits and spans to retrieve the answer
# Normalize logits and spans to retrieve the answer
start_
=
np
.
exp
(
start_
)
/
np
.
sum
(
np
.
exp
(
start_
))
start_
=
np
.
exp
(
start_
)
/
np
.
sum
(
np
.
exp
(
start_
))
end_
=
np
.
exp
(
end_
)
/
np
.
sum
(
np
.
exp
(
end_
))
end_
=
np
.
exp
(
end_
)
/
np
.
sum
(
np
.
exp
(
end_
))
...
...
transformers/tests/pipelines_test.py
View file @
9a24e0cf
...
@@ -40,7 +40,38 @@ class QuestionAnsweringPipelineTest(unittest.TestCase):
...
@@ -40,7 +40,38 @@ class QuestionAnsweringPipelineTest(unittest.TestCase):
# Batch case with topk = 2
# Batch case with topk = 2
a
=
nlp
(
question
=
[
'What is the name of the company I
\'
m working for ?'
,
'Where is the company based ?'
],
a
=
nlp
(
question
=
[
'What is the name of the company I
\'
m working for ?'
,
'Where is the company based ?'
],
context
=
[
'I
\'
m working for Huggingface.'
,
'The company is based in New York and Paris'
],
topk
=
2
)
context
=
[
'Where is the company based ?'
,
'The company is based in New York and Paris'
],
topk
=
2
)
self
.
check_answer_structure
(
a
,
2
,
2
)
# check for data keyword
a
=
nlp
(
data
=
nlp
.
create_sample
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
))
self
.
check_answer_structure
(
a
,
1
,
1
)
a
=
nlp
(
data
=
nlp
.
create_sample
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
),
topk
=
2
)
self
.
check_answer_structure
(
a
,
1
,
2
)
a
=
nlp
(
data
=
[
nlp
.
create_sample
(
question
=
'What is the name of the company I
\'
m working for ?'
,
context
=
'I
\'
m working for Huggingface.'
),
nlp
.
create_sample
(
question
=
'I
\'
m working for Huggingface.'
,
context
=
'The company is based in New York and Paris'
),
])
self
.
check_answer_structure
(
a
,
2
,
1
)
a
=
nlp
(
data
=
[
{
'question'
:
'What is the name of the company I
\'
m working for ?'
,
'context'
:
'I
\'
m working for Huggingface.'
},
{
'question'
:
'Where is the company based ?'
,
'context'
:
'The company is based in New York and Paris'
},
])
self
.
check_answer_structure
(
a
,
2
,
1
)
# X keywords
a
=
nlp
(
X
=
nlp
.
create_sample
(
question
=
'Where is the company based ?'
,
context
=
'The company is based in New York and Paris'
))
self
.
check_answer_structure
(
a
,
1
,
1
)
a
=
nlp
(
X
=
[
{
'question'
:
'What is the name of the company I
\'
m working for ?'
,
'context'
:
'I
\'
m working for Huggingface.'
},
{
'question'
:
'Where is the company based ?'
,
'context'
:
'The company is based in New York and Paris'
},
],
topk
=
2
)
self
.
check_answer_structure
(
a
,
2
,
2
)
self
.
check_answer_structure
(
a
,
2
,
2
)
@
patch
(
'transformers.pipelines.is_torch_available'
,
return_value
=
False
)
@
patch
(
'transformers.pipelines.is_torch_available'
,
return_value
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment