Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1fa93ca1
Commit
1fa93ca1
authored
Dec 20, 2019
by
thomwolf
Browse files
Clean up framework handling
parent
ca6bdb28
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
58 additions
and
27 deletions
+58
-27
transformers/pipelines.py
transformers/pipelines.py
+58
-27
No files found.
transformers/pipelines.py
View file @
1fa93ca1
...
...
@@ -48,6 +48,19 @@ if is_torch_available():
logger
=
logging
.
getLogger
(
__name__
)
def
get_framework
(
model
=
None
):
if
is_tf_available
()
and
is_torch_available
()
and
model
is
not
None
and
not
isinstance
(
model
,
str
):
# Both framework are available but the use supplied a model class instance.
# Try to guess which framework to use from the model classname
framework
=
'tf'
if
model
.
__class__
.
__name__
.
startswith
(
'TF'
)
else
'pt'
else
:
framework
=
'tf'
if
is_tf_available
()
else
(
'pt'
if
is_torch_available
()
else
None
)
if
framework
is
None
:
raise
ImportError
(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/."
)
return
framework
class
ArgumentHandler
(
ABC
):
"""
Base interface for handling varargs for each Pipeline
...
...
@@ -279,19 +292,23 @@ class Pipeline(_ScikitCompat):
nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
modelcard
:
ModelCard
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
binary_output
:
bool
=
False
):
if
framework
is
None
:
framework
=
get_framework
()
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
modelcard
=
modelcard
self
.
framework
=
framework
self
.
device
=
device
self
.
binary_output
=
binary_output
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
# Special handling
if
self
.
device
>=
0
and
not
is_tf_available
()
:
if
self
.
device
>=
0
and
self
.
framework
==
'pt'
:
self
.
model
=
self
.
model
.
to
(
'cuda:{}'
.
format
(
self
.
device
))
def
save_pretrained
(
self
,
save_directory
):
...
...
@@ -332,7 +349,7 @@ class Pipeline(_ScikitCompat):
Returns:
Context manager
"""
if
is_tf_available
()
:
if
self
.
framework
==
'tf'
:
with
tf
.
device
(
'/CPU:0'
if
self
.
device
==
-
1
else
'/device:GPU:{}'
.
format
(
self
.
device
)):
yield
else
:
...
...
@@ -371,7 +388,7 @@ class Pipeline(_ScikitCompat):
with
self
.
device_placement
():
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
return_tensors
=
self
.
framework
,
max_length
=
self
.
tokenizer
.
max_len
)
...
...
@@ -387,7 +404,7 @@ class Pipeline(_ScikitCompat):
Returns:
Numpy array
"""
if
is_tf_available
()
:
if
self
.
framework
==
'tf'
:
# TODO trace model
predictions
=
self
.
model
(
inputs
,
training
=
False
)[
0
]
else
:
...
...
@@ -405,9 +422,16 @@ class FeatureExtractionPipeline(Pipeline):
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
):
super
().
__init__
(
model
,
tokenizer
,
modelcard
,
args_parser
,
device
,
binary_output
=
True
)
super
().
__init__
(
model
=
model
,
tokenizer
=
tokenizer
,
modelcard
=
modelcard
,
framework
=
framework
,
args_parser
=
args_parser
,
device
=
device
,
binary_output
=
True
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
...
...
@@ -430,10 +454,16 @@ class NerPipeline(Pipeline):
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
modelcard
:
ModelCard
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
binary_output
:
bool
=
False
):
super
().
__init__
(
model
,
tokenizer
,
modelcard
,
args_parser
,
device
,
binary_output
)
super
().
__init__
(
model
=
model
,
tokenizer
=
tokenizer
,
modelcard
=
modelcard
,
framework
=
framework
,
args_parser
=
args_parser
,
device
=
device
,
binary_output
=
binary_output
)
self
.
_basic_tokenizer
=
BasicTokenizer
(
do_lower_case
=
False
)
...
...
@@ -452,12 +482,12 @@ class NerPipeline(Pipeline):
tokens
=
self
.
tokenizer
.
encode_plus
(
sentence
,
return_attention_mask
=
False
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
return_tensors
=
self
.
framework
,
max_length
=
self
.
tokenizer
.
max_len
)
# Forward
if
is_tf_available
()
:
if
self
.
framework
==
'tf'
:
entities
=
self
.
model
(
tokens
)[
0
][
0
].
numpy
()
else
:
with
torch
.
no_grad
():
...
...
@@ -549,6 +579,18 @@ class QuestionAnsweringPipeline(Pipeline):
Question Answering pipeline using ModelForQuestionAnswering head.
"""
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
],
modelcard
:
Optional
[
ModelCard
],
framework
:
Optional
[
str
]
=
None
,
device
:
int
=
-
1
,
**
kwargs
):
super
().
__init__
(
model
=
model
,
tokenizer
=
tokenizer
,
modelcard
=
modelcard
,
framework
=
framework
,
args_parser
=
QuestionAnsweringArgumentHandler
(),
device
=
device
,
**
kwargs
)
@
staticmethod
def
create_sample
(
question
:
Union
[
str
,
List
[
str
]],
context
:
Union
[
str
,
List
[
str
]])
->
Union
[
SquadExample
,
List
[
SquadExample
]]:
"""
...
...
@@ -567,12 +609,6 @@ class QuestionAnsweringPipeline(Pipeline):
else
:
return
SquadExample
(
None
,
question
,
context
,
None
,
None
,
None
)
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
],
modelcard
:
Optional
[
ModelCard
],
device
:
int
=
-
1
,
**
kwargs
):
super
().
__init__
(
model
,
tokenizer
,
modelcard
,
args_parser
=
QuestionAnsweringArgumentHandler
(),
device
=
device
,
**
kwargs
)
def
__call__
(
self
,
*
texts
,
**
kwargs
):
"""
Args:
...
...
@@ -608,7 +644,7 @@ class QuestionAnsweringPipeline(Pipeline):
# Manage tensor allocation on correct device
with
self
.
device_placement
():
if
is_tf_available
()
:
if
self
.
framework
==
'tf'
:
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
start
,
end
=
self
.
model
(
fw_args
)
start
,
end
=
start
.
numpy
(),
end
.
numpy
()
...
...
@@ -798,15 +834,10 @@ def pipeline(task: str, model: Optional = None,
if
task
not
in
SUPPORTED_TASKS
:
raise
KeyError
(
"Unknown task {}, available tasks are {}"
.
format
(
task
,
list
(
SUPPORTED_TASKS
.
keys
())))
pipeline_framework
=
'tf'
if
is_tf_available
()
else
(
'pt'
if
is_torch_available
()
else
None
)
if
pipeline_framework
is
None
:
raise
ImportError
(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/."
)
framework
=
get_framework
(
model
)
targeted_task
=
SUPPORTED_TASKS
[
task
]
task
,
model_class
=
targeted_task
[
'impl'
],
targeted_task
[
pipeline_
framework
]
task
,
model_class
=
targeted_task
[
'impl'
],
targeted_task
[
framework
]
# Use default model/config/tokenizer for the task if no model is provided
if
model
is
None
:
...
...
@@ -843,14 +874,14 @@ def pipeline(task: str, model: Optional = None,
if
isinstance
(
model
,
str
):
# Handle transparent TF/PT model conversion
model_kwargs
=
{}
if
pipeline_
framework
==
'pt'
and
model
.
endswith
(
'.h5'
):
if
framework
==
'pt'
and
model
.
endswith
(
'.h5'
):
model_kwargs
[
'from_tf'
]
=
True
logger
.
warning
(
'Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
'Trying to load the model with PyTorch.'
)
elif
pipeline_
framework
==
'tf'
and
model
.
endswith
(
'.bin'
):
elif
framework
==
'tf'
and
model
.
endswith
(
'.bin'
):
model_kwargs
[
'from_pt'
]
=
True
logger
.
warning
(
'Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
'Trying to load the model with Tensorflow.'
)
model
=
model_class
.
from_pretrained
(
model
,
config
=
config
,
**
model_kwargs
)
return
task
(
model
,
tokenizer
,
**
kwargs
)
return
task
(
model
=
model
,
tokenizer
=
tokenizer
,
modelcard
=
modelcard
,
framework
=
framework
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment