chenpangpang / transformers

Commit 955d7ecb, authored Dec 16, 2019 by Morgan Funtowicz
Parent: 8e3b1c86

    Refactored Pipeline with dedicated argument handler.

Showing 1 changed file with 112 additions and 98 deletions.

transformers/pipelines.py (+112 −98)
@@ -36,29 +36,40 @@ if is_torch_available():
         AutoModelForQuestionAnswering, AutoModelForTokenClassification


-class Pipeline(ABC):
-    def __init__(self, model, tokenizer: PreTrainedTokenizer = None, **kwargs):
-        self.model = model
-        self.tokenizer = tokenizer
+class ArgumentHandler(ABC):
+    """
+    Base interface for handling varargs for each Pipeline
+    """
+    @abstractmethod
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError()

-    def save_pretrained(self, save_directory):
-        if not os.path.isdir(save_directory):
-            logger.error("Provided path ({}) should be a directory".format(save_directory))
-            return
-        self.model.save_pretrained(save_directory)
-        self.tokenizer.save_pretrained(save_directory)

+class DefaultArgumentHandler(ArgumentHandler):
+    """
+    Default varargs argument parser handling parameters for each Pipeline
+    """
+    def __call__(self, *args, **kwargs):
+        if 'X' in kwargs:
+            return kwargs['X']
+        elif 'data' in kwargs:
+            return kwargs['data']
+        elif len(args) > 0:
+            return list(args)
+        raise ValueError('Unable to infer the format of the provided data (X=, data=, ...)')

-    def transform(self, *texts, **kwargs):
-        # Generic compatibility with sklearn and Keras
-        return self(*texts, **kwargs)
-
-    def predict(self, *texts, **kwargs):
-        # Generic compatibility with sklearn and Keras
-        return self(*texts, **kwargs)

+class _ScikitCompat(ABC):
+    """
+    Interface layer for the Scikit and Keras compatibility.
+    """
     @abstractmethod
-    def __call__(self, *texts, **kwargs):
+    def transform(self, X):
         raise NotImplementedError()

+    @abstractmethod
+    def predict(self, X):
+        raise NotImplementedError()
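For orientation while reading the rest of the diff: the new DefaultArgumentHandler is what lets every pipeline accept positional texts, an sklearn-style X= keyword, or a data= keyword interchangeably. Below is a minimal, self-contained sketch of that behaviour; the handler code simply mirrors the lines added in this hunk, and the sample inputs are made up.

from abc import ABC, abstractmethod


class ArgumentHandler(ABC):
    # Base interface for handling varargs for each Pipeline (as added in this hunk)
    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class DefaultArgumentHandler(ArgumentHandler):
    # Default varargs argument parser handling parameters for each Pipeline
    def __call__(self, *args, **kwargs):
        if 'X' in kwargs:
            return kwargs['X']
        elif 'data' in kwargs:
            return kwargs['data']
        elif len(args) > 0:
            return list(args)
        raise ValueError('Unable to infer the format of the provided data (X=, data=, ...)')


handler = DefaultArgumentHandler()
print(handler('Hello world'))            # ['Hello world']    (positional varargs)
print(handler(X=['Hello', 'world']))     # ['Hello', 'world'] (sklearn-style X=)
print(handler(data=['Hello', 'world']))  # ['Hello', 'world'] (data= keyword)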
@@ -133,24 +144,45 @@ class JsonPipelineDataFormat(PipelineDataFormat):
             if self.is_multi_columns:
                 yield {k: entry[c] for k, c in self.column}
             else:
-                yield entry[self.column]
+                yield entry[self.column[0]]

     def save(self, data: dict):
         with open(self.output, 'w') as f:
             json.dump(data, f)


-class FeatureExtractionPipeline(Pipeline):
+class Pipeline(_ScikitCompat):
+    def __init__(self, model, tokenizer: PreTrainedTokenizer = None, args_parser: ArgumentHandler = None, **kwargs):
+        self.model = model
+        self.tokenizer = tokenizer
+        self._args_parser = args_parser or DefaultArgumentHandler()
+
+    def save_pretrained(self, save_directory):
+        if not os.path.isdir(save_directory):
+            logger.error("Provided path ({}) should be a directory".format(save_directory))
+            return
+        self.model.save_pretrained(save_directory)
+        self.tokenizer.save_pretrained(save_directory)
+
+    def transform(self, X):
+        return self(X=X)
+
+    def predict(self, X):
+        return self(X=X)

     def __call__(self, *texts, **kwargs):
-        # Generic compatibility with sklearn and Keras
-        if 'X' in kwargs and not texts:
-            texts = kwargs.pop('X')
+        # Parse arguments
+        inputs = self._args_parser(*texts, **kwargs)

         # Encode for forward
         inputs = self.tokenizer.batch_encode_plus(
-            texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
+            inputs, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
         )
+        return self._forward(inputs)

+    def _forward(self, inputs):
         if is_tf_available():
             # TODO trace model
             predictions = self.model(inputs)[0]
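The net effect of this hunk: the base Pipeline now owns the whole call path (argument parsing via self._args_parser, tokenization via batch_encode_plus, then _forward), and the args_parser constructor argument is the extension point. As a sketch of that hook, here is a hypothetical handler that lower-cases inputs before tokenization; the import path follows this file's location (transformers/pipelines.py), and the commented construction line assumes a model/tokenizer pair is already loaded.

from transformers.pipelines import ArgumentHandler, DefaultArgumentHandler


class LowercaseArgumentHandler(ArgumentHandler):
    # Hypothetical handler: normalize inputs to lower case before they reach the tokenizer.
    def __call__(self, *args, **kwargs):
        texts = DefaultArgumentHandler()(*args, **kwargs)
        if isinstance(texts, str):
            texts = [texts]
        return [text.lower() for text in texts]


# Opt-in at construction time; when args_parser is None, Pipeline.__init__ falls back
# to DefaultArgumentHandler(), so existing call sites keep working unchanged:
# pipe = FeatureExtractionPipeline(model, tokenizer, args_parser=LowercaseArgumentHandler())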
@@ -159,7 +191,12 @@ class FeatureExtractionPipeline(Pipeline):
             with torch.no_grad():
                 predictions = self.model(**inputs)[0]

-        return predictions.numpy().tolist()
+        return predictions.numpy()
+
+
+class FeatureExtractionPipeline(Pipeline):
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs).tolist()


class TextClassificationPipeline(Pipeline):
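With _forward returning a raw numpy array, FeatureExtractionPipeline reduces to a one-line subclass that converts the result to plain Python lists. A hedged end-to-end sketch under the new API; the checkpoint name is illustrative and the import path assumes the classes live where this diff puts them (transformers/pipelines.py).

from transformers import AutoModel, AutoTokenizer
from transformers.pipelines import FeatureExtractionPipeline

# Illustrative checkpoint; any model/tokenizer pair accepted by Pipeline.__init__ should do.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
model = AutoModel.from_pretrained('bert-base-cased')

pipe = FeatureExtractionPipeline(model, tokenizer)

# All three spellings are routed through DefaultArgumentHandler and return nested lists,
# because FeatureExtractionPipeline.__call__ applies .tolist() to the base pipeline's output.
features = pipe('Pipelines now share a single call path.')
features = pipe(X=['first sentence', 'second sentence'])
features = pipe.predict(['first sentence', 'second sentence'])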
@@ -170,26 +207,8 @@ class TextClassificationPipeline(Pipeline):
             raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes))
         self._nb_classes = nb_classes

-    def __call__(self, *texts, **kwargs):
-        # Generic compatibility with sklearn and Keras
-        if 'X' in kwargs and not texts:
-            texts = kwargs.pop('X')
-
-        inputs = self.tokenizer.batch_encode_plus(
-            texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
-        )
-        special_tokens_mask = inputs.pop('special_tokens_mask')
-
-        if is_tf_available():
-            # TODO trace model
-            predictions = self.model(**inputs)[0]
-        else:
-            import torch
-            with torch.no_grad():
-                predictions = self.model(**inputs)[0]
-
-        return predictions.numpy().tolist()
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs).tolist()


class NerPipeline(Pipeline):
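TextClassificationPipeline gets the same treatment: its bespoke tokenize-and-forward code is deleted in favour of a thin __call__ over the shared base. As an illustration of how a downstream pipeline can now layer its own post-processing on top of Pipeline._forward, here is a hypothetical subclass (not part of this commit); it assumes the base class is importable from transformers.pipelines as in the diff.

import numpy as np

from transformers.pipelines import Pipeline


class ArgmaxClassificationPipeline(Pipeline):
    # Hypothetical subclass: reuse the shared parse/tokenize/forward path, then argmax the logits.
    def __call__(self, *args, **kwargs):
        logits = super().__call__(*args, **kwargs)  # numpy array returned by Pipeline._forward
        return np.argmax(logits, axis=-1).tolist()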
@@ -198,8 +217,7 @@ class NerPipeline(Pipeline):
         super().__init__(model, tokenizer)

     def __call__(self, *texts, **kwargs):
-        (texts,), answers = texts, []
+        inputs, answers = self._args_parser(*texts, **kwargs), []
         for sentence in texts:

             # Ugly token to word idx mapping (for now)
@@ -241,24 +259,12 @@ class QuestionAnsweringPipeline(Pipeline):
     Question Answering pipeline involving Tokenization and Inference.
     """

-    @classmethod
-    def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
-        pass
-
-    @staticmethod
-    def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
-        is_list = isinstance(question, list)
-
-        if is_list:
-            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
-        else:
-            return SquadExample(None, question, context, None, None, None)
-
-    @staticmethod
-    def handle_args(*inputs, **kwargs) -> List[SquadExample]:
+    class QuestionAnsweringArgumentHandler(ArgumentHandler):
+        def __call__(self, *args, **kwargs):
             # Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
-            if inputs is not None and len(inputs) > 1:
-                kwargs['X'] = inputs
+            if args is not None and len(args) > 1:
+                kwargs['X'] = args

             # Generic compatibility with sklearn and Keras
             # Batched data
@@ -300,8 +306,17 @@ class QuestionAnsweringPipeline(Pipeline):

         return inputs

+    @staticmethod
+    def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
+        is_list = isinstance(question, list)
+
+        if is_list:
+            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
+        else:
+            return SquadExample(None, question, context, None, None, None)
+
     def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
-        super().__init__(model, tokenizer)
+        super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler())

     def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
         args = ['input_ids', 'attention_mask']
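Since create_sample accepts either single strings or parallel lists of questions and contexts, both shapes are shown below; the values are made up and the import path again assumes transformers.pipelines.

from transformers.pipelines import QuestionAnsweringPipeline

# One question/context pair -> a single SquadExample
sample = QuestionAnsweringPipeline.create_sample(
    question='What was refactored?',
    context='The Pipeline class was refactored with a dedicated argument handler.')

# Parallel lists -> a list of SquadExample, zipped pairwise
samples = QuestionAnsweringPipeline.create_sample(
    question=['What was added?', 'What was removed?'],
    context=['An ArgumentHandler hierarchy was added.', 'The old handle_args helper was removed.'])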
@@ -332,9 +347,8 @@ class QuestionAnsweringPipeline(Pipeline):
         if kwargs['max_answer_len'] < 1:
             raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len']))

-        examples = QuestionAnsweringPipeline.handle_args(texts, **kwargs)
-
         # Convert inputs to features
+        examples = self._args_parser(*texts, **kwargs)
         features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
         fw_args = self.inputs_for_model(features)
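For completeness, the tail of the QA call path after this change: the argument handler yields SquadExample objects, which are converted to model features by squad_convert_examples_to_features exactly as in the hunk above. A hedged sketch of that conversion step in isolation; the numeric limits are illustrative stand-ins, not values read from this file, and the import paths assume the data.processors.squad module of the same era.

from transformers import AutoTokenizer
from transformers.data.processors.squad import SquadExample, squad_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

# Same constructor shape as used by create_sample in the diff: (qas_id, question, context, ...)
examples = [SquadExample(None, 'Who authored the commit?', 'Morgan Funtowicz authored the commit.', None, None, None)]

# Positional call mirroring the diff: examples, tokenizer, max_seq_len, doc_stride, max_question_len, is_training
features = squad_convert_examples_to_features(examples, tokenizer, 384, 128, 64, False)

# Rough stand-in for inputs_for_model(features), which selects 'input_ids' and 'attention_mask':
fw_args = {'input_ids': [f.input_ids for f in features],
           'attention_mask': [f.attention_mask for f in features]}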