Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
955d7ecb
Commit
955d7ecb
authored
Dec 16, 2019
by
Morgan Funtowicz
Browse files
Refactored Pipeline with dedicated argument handler.
parent
8e3b1c86
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
112 additions
and
98 deletions
+112
-98
transformers/pipelines.py
transformers/pipelines.py
+112
-98
No files found.
transformers/pipelines.py
View file @
955d7ecb
...
@@ -36,29 +36,40 @@ if is_torch_available():
...
@@ -36,29 +36,40 @@ if is_torch_available():
AutoModelForQuestionAnswering
,
AutoModelForTokenClassification
AutoModelForQuestionAnswering
,
AutoModelForTokenClassification
class
Pipeline
(
ABC
):
class
ArgumentHandler
(
ABC
):
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
**
kwargs
):
"""
self
.
model
=
model
Base interface for handling varargs for each Pipeline
self
.
tokenizer
=
tokenizer
"""
@
abstractmethod
def
__call__
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
def
save_pretrained
(
self
,
save_directory
):
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
"Provided path ({}) should be a directory"
.
format
(
save_directory
))
return
self
.
model
.
save_pretrained
(
save_directory
)
class
DefaultArgumentHandler
(
ArgumentHandler
):
self
.
tokenizer
.
save_pretrained
(
save_directory
)
"""
Default varargs argument parser handling parameters for each Pipeline
"""
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
'X'
in
kwargs
:
return
kwargs
[
'X'
]
elif
'data'
in
kwargs
:
return
kwargs
[
'data'
]
elif
len
(
args
)
>
0
:
return
list
(
args
)
raise
ValueError
(
'Unable to infer the format of the provided data (X=, data=, ...)'
)
def
transform
(
self
,
*
texts
,
**
kwargs
):
# Generic compatibility with sklearn and Keras
return
self
(
*
texts
,
**
kwargs
)
def
predict
(
self
,
*
texts
,
**
kwargs
):
class
_ScikitCompat
(
ABC
):
# Generic compatibility with sklearn and Keras
"""
return
self
(
*
texts
,
**
kwargs
)
Interface layer for the Scikit and Keras compatibility.
"""
@
abstractmethod
@
abstractmethod
def
__call__
(
self
,
*
texts
,
**
kwargs
):
def
transform
(
self
,
X
):
raise
NotImplementedError
()
@
abstractmethod
def
predict
(
self
,
X
):
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -133,24 +144,45 @@ class JsonPipelineDataFormat(PipelineDataFormat):
...
@@ -133,24 +144,45 @@ class JsonPipelineDataFormat(PipelineDataFormat):
if
self
.
is_multi_columns
:
if
self
.
is_multi_columns
:
yield
{
k
:
entry
[
c
]
for
k
,
c
in
self
.
column
}
yield
{
k
:
entry
[
c
]
for
k
,
c
in
self
.
column
}
else
:
else
:
yield
entry
[
self
.
column
]
yield
entry
[
self
.
column
[
0
]
]
def
save
(
self
,
data
:
dict
):
def
save
(
self
,
data
:
dict
):
with
open
(
self
.
output
,
'w'
)
as
f
:
with
open
(
self
.
output
,
'w'
)
as
f
:
json
.
dump
(
data
,
f
)
json
.
dump
(
data
,
f
)
class
FeatureExtractionPipeline
(
Pipeline
):
class
Pipeline
(
_ScikitCompat
):
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
**
kwargs
):
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
def
save_pretrained
(
self
,
save_directory
):
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
"Provided path ({}) should be a directory"
.
format
(
save_directory
))
return
self
.
model
.
save_pretrained
(
save_directory
)
self
.
tokenizer
.
save_pretrained
(
save_directory
)
def
transform
(
self
,
X
):
return
self
(
X
=
X
)
def
predict
(
self
,
X
):
return
self
(
X
=
X
)
def
__call__
(
self
,
*
texts
,
**
kwargs
):
def
__call__
(
self
,
*
texts
,
**
kwargs
):
# Generic compatibility with sklearn and Keras
# Parse arguments
if
'X'
in
kwargs
and
not
texts
:
inputs
=
self
.
_args_parser
(
*
texts
,
**
kwargs
)
texts
=
kwargs
.
pop
(
'X'
)
# Encode for forward
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
tex
ts
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
inpu
ts
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
)
)
return
self
.
_forward
(
inputs
)
def
_forward
(
self
,
inputs
):
if
is_tf_available
():
if
is_tf_available
():
# TODO trace model
# TODO trace model
predictions
=
self
.
model
(
inputs
)[
0
]
predictions
=
self
.
model
(
inputs
)[
0
]
...
@@ -159,7 +191,12 @@ class FeatureExtractionPipeline(Pipeline):
...
@@ -159,7 +191,12 @@ class FeatureExtractionPipeline(Pipeline):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
predictions
=
self
.
model
(
**
inputs
)[
0
]
predictions
=
self
.
model
(
**
inputs
)[
0
]
return
predictions
.
numpy
().
tolist
()
return
predictions
.
numpy
()
class
FeatureExtractionPipeline
(
Pipeline
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
class
TextClassificationPipeline
(
Pipeline
):
class
TextClassificationPipeline
(
Pipeline
):
...
@@ -170,26 +207,8 @@ class TextClassificationPipeline(Pipeline):
...
@@ -170,26 +207,8 @@ class TextClassificationPipeline(Pipeline):
raise
Exception
(
'Invalid parameter nb_classes. int >= 2 is required (got: {})'
.
format
(
nb_classes
))
raise
Exception
(
'Invalid parameter nb_classes. int >= 2 is required (got: {})'
.
format
(
nb_classes
))
self
.
_nb_classes
=
nb_classes
self
.
_nb_classes
=
nb_classes
def
__call__
(
self
,
*
texts
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
# Generic compatibility with sklearn and Keras
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
if
'X'
in
kwargs
and
not
texts
:
texts
=
kwargs
.
pop
(
'X'
)
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
texts
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
)
special_tokens_mask
=
inputs
.
pop
(
'special_tokens_mask'
)
if
is_tf_available
():
# TODO trace model
predictions
=
self
.
model
(
**
inputs
)[
0
]
else
:
import
torch
with
torch
.
no_grad
():
predictions
=
self
.
model
(
**
inputs
)[
0
]
return
predictions
.
numpy
().
tolist
()
class
NerPipeline
(
Pipeline
):
class
NerPipeline
(
Pipeline
):
...
@@ -198,8 +217,7 @@ class NerPipeline(Pipeline):
...
@@ -198,8 +217,7 @@ class NerPipeline(Pipeline):
super
().
__init__
(
model
,
tokenizer
)
super
().
__init__
(
model
,
tokenizer
)
def
__call__
(
self
,
*
texts
,
**
kwargs
):
def
__call__
(
self
,
*
texts
,
**
kwargs
):
(
texts
,
),
answers
=
texts
,
[]
inputs
,
answers
=
self
.
_args_parser
(
*
texts
,
**
kwargs
),
[]
for
sentence
in
texts
:
for
sentence
in
texts
:
# Ugly token to word idx mapping (for now)
# Ugly token to word idx mapping (for now)
...
@@ -241,9 +259,52 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -241,9 +259,52 @@ class QuestionAnsweringPipeline(Pipeline):
Question Answering pipeline involving Tokenization and Inference.
Question Answering pipeline involving Tokenization and Inference.
"""
"""
@
classmethod
class
QuestionAnsweringArgumentHandler
(
ArgumentHandler
):
def
from_config
(
cls
,
model
,
tokenizer
:
PreTrainedTokenizer
,
**
kwargs
):
pass
def
__call__
(
self
,
*
args
,
**
kwargs
):
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if
args
is
not
None
and
len
(
args
)
>
1
:
kwargs
[
'X'
]
=
args
# Generic compatibility with sklearn and Keras
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
for
i
,
item
in
enumerate
(
data
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
continue
else
:
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
inputs
=
data
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
if
isinstance
(
kwargs
[
'question'
],
str
):
kwargs
[
'question'
]
=
[
kwargs
[
'question'
]]
if
isinstance
(
kwargs
[
'context'
],
str
):
kwargs
[
'context'
]
=
[
kwargs
[
'context'
]]
inputs
=
[
QuestionAnsweringPipeline
.
create_sample
(
q
,
c
)
for
q
,
c
in
zip
(
kwargs
[
'question'
],
kwargs
[
'context'
])]
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
inputs
,
list
):
inputs
=
[
inputs
]
return
inputs
@
staticmethod
@
staticmethod
def
create_sample
(
question
:
Union
[
str
,
List
[
str
]],
context
:
Union
[
str
,
List
[
str
]])
->
Union
[
SquadExample
,
List
[
SquadExample
]]:
def
create_sample
(
question
:
Union
[
str
,
List
[
str
]],
context
:
Union
[
str
,
List
[
str
]])
->
Union
[
SquadExample
,
List
[
SquadExample
]]:
...
@@ -254,54 +315,8 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -254,54 +315,8 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
else
:
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
@
staticmethod
def
handle_args
(
*
inputs
,
**
kwargs
)
->
List
[
SquadExample
]:
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if
inputs
is
not
None
and
len
(
inputs
)
>
1
:
kwargs
[
'X'
]
=
inputs
# Generic compatibility with sklearn and Keras
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
for
i
,
item
in
enumerate
(
data
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
continue
else
:
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
inputs
=
data
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
if
isinstance
(
kwargs
[
'question'
],
str
):
kwargs
[
'question'
]
=
[
kwargs
[
'question'
]]
if
isinstance
(
kwargs
[
'context'
],
str
):
kwargs
[
'context'
]
=
[
kwargs
[
'context'
]]
inputs
=
[
QuestionAnsweringPipeline
.
create_sample
(
q
,
c
)
for
q
,
c
in
zip
(
kwargs
[
'question'
],
kwargs
[
'context'
])]
else
:
raise
ValueError
(
'Unknown arguments {}'
.
format
(
kwargs
))
if
not
isinstance
(
inputs
,
list
):
inputs
=
[
inputs
]
return
inputs
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
]):
super
().
__init__
(
model
,
tokenizer
)
super
().
__init__
(
model
,
tokenizer
,
args_parser
=
QuestionAnsweringPipeline
.
QuestionAnsweringArgumentHandler
()
)
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
def
inputs_for_model
(
self
,
features
:
Union
[
SquadExample
,
List
[
SquadExample
]])
->
Dict
:
args
=
[
'input_ids'
,
'attention_mask'
]
args
=
[
'input_ids'
,
'attention_mask'
]
...
@@ -332,9 +347,8 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -332,9 +347,8 @@ class QuestionAnsweringPipeline(Pipeline):
if
kwargs
[
'max_answer_len'
]
<
1
:
if
kwargs
[
'max_answer_len'
]
<
1
:
raise
ValueError
(
'max_answer_len parameter should be >= 1 (got {})'
.
format
(
kwargs
[
'max_answer_len'
]))
raise
ValueError
(
'max_answer_len parameter should be >= 1 (got {})'
.
format
(
kwargs
[
'max_answer_len'
]))
examples
=
QuestionAnsweringPipeline
.
handle_args
(
texts
,
**
kwargs
)
# Convert inputs to features
# Convert inputs to features
examples
=
self
.
_args_parser
(
*
texts
,
**
kwargs
)
features
=
squad_convert_examples_to_features
(
examples
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
features
=
squad_convert_examples_to_features
(
examples
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
fw_args
=
self
.
inputs_for_model
(
features
)
fw_args
=
self
.
inputs_for_model
(
features
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment