Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3492a6ec
Commit
3492a6ec
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Addressing Thom's comments.
parent
81a911cc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
37 deletions
+33
-37
transformers/pipelines.py
transformers/pipelines.py
+33
-37
No files found.
transformers/pipelines.py
View file @
3492a6ec
...
@@ -30,6 +30,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, Pretrai
...
@@ -30,6 +30,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, Pretrai
SquadExample
,
squad_convert_examples_to_features
,
is_tf_available
,
is_torch_available
,
logger
SquadExample
,
squad_convert_examples_to_features
,
is_tf_available
,
is_torch_available
,
logger
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers
import
TFAutoModel
,
TFAutoModelForSequenceClassification
,
\
from
transformers
import
TFAutoModel
,
TFAutoModelForSequenceClassification
,
\
TFAutoModelForQuestionAnswering
,
TFAutoModelForTokenClassification
TFAutoModelForQuestionAnswering
,
TFAutoModelForTokenClassification
...
@@ -79,9 +80,9 @@ class PipelineDataFormat:
...
@@ -79,9 +80,9 @@ class PipelineDataFormat:
"""
"""
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
def
__init__
(
self
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
self
.
output
=
output
self
.
output
=
output
self
.
path
=
path
self
.
path
=
input
self
.
column
=
column
.
split
(
','
)
if
column
else
[
''
]
self
.
column
=
column
.
split
(
','
)
if
column
else
[
''
]
self
.
is_multi_columns
=
len
(
self
.
column
)
>
1
self
.
is_multi_columns
=
len
(
self
.
column
)
>
1
...
@@ -92,7 +93,7 @@ class PipelineDataFormat:
...
@@ -92,7 +93,7 @@ class PipelineDataFormat:
if
exists
(
abspath
(
self
.
output
)):
if
exists
(
abspath
(
self
.
output
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
if
path
is
not
None
:
if
input
is
not
None
:
if
not
exists
(
abspath
(
self
.
path
)):
if
not
exists
(
abspath
(
self
.
path
)):
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
...
@@ -136,8 +137,8 @@ class PipelineDataFormat:
...
@@ -136,8 +137,8 @@ class PipelineDataFormat:
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
super
().
__init__
(
output
,
path
,
column
)
super
().
__init__
(
output
,
input
,
column
)
def
__iter__
(
self
):
def
__iter__
(
self
):
with
open
(
self
.
path
,
'r'
)
as
f
:
with
open
(
self
.
path
,
'r'
)
as
f
:
...
@@ -157,10 +158,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
...
@@ -157,10 +158,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
super
().
__init__
(
output
,
path
,
column
)
super
().
__init__
(
output
,
input
,
column
)
with
open
(
path
,
'r'
)
as
f
:
with
open
(
input
,
'r'
)
as
f
:
self
.
_entries
=
json
.
load
(
f
)
self
.
_entries
=
json
.
load
(
f
)
def
__iter__
(
self
):
def
__iter__
(
self
):
...
@@ -321,11 +322,9 @@ class Pipeline(_ScikitCompat):
...
@@ -321,11 +322,9 @@ class Pipeline(_ScikitCompat):
Context manager
Context manager
"""
"""
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
with
tf
.
device
(
'/CPU:0'
if
self
.
device
==
-
1
else
'/device:GPU:{}'
.
format
(
self
.
device
)):
with
tf
.
device
(
'/CPU:0'
if
self
.
device
==
-
1
else
'/device:GPU:{}'
.
format
(
self
.
device
)):
yield
yield
else
:
else
:
import
torch
if
self
.
device
>=
0
:
if
self
.
device
>=
0
:
torch
.
cuda
.
set_device
(
self
.
device
)
torch
.
cuda
.
set_device
(
self
.
device
)
...
@@ -358,11 +357,10 @@ class Pipeline(_ScikitCompat):
...
@@ -358,11 +357,10 @@ class Pipeline(_ScikitCompat):
# Encode for forward
# Encode for forward
with
self
.
device_placement
():
with
self
.
device_placement
():
# TODO : Remove this 512 hard-limit
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
,
add_special_tokens
=
True
,
inputs
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
max_length
=
512
max_length
=
self
.
tokenizer
.
max_len
)
)
# Filter out features not available on specific models
# Filter out features not available on specific models
...
@@ -379,11 +377,10 @@ class Pipeline(_ScikitCompat):
...
@@ -379,11 +377,10 @@ class Pipeline(_ScikitCompat):
"""
"""
if
is_tf_available
():
if
is_tf_available
():
# TODO trace model
# TODO trace model
predictions
=
self
.
model
(
inputs
)[
0
]
predictions
=
self
.
model
(
inputs
,
training
=
False
)[
0
]
else
:
else
:
import
torch
with
torch
.
no_grad
():
with
torch
.
no_grad
():
predictions
=
self
.
model
(
**
inputs
)[
0
]
predictions
=
self
.
model
(
**
inputs
)
.
cpu
()
[
0
]
return
predictions
.
numpy
()
return
predictions
.
numpy
()
...
@@ -432,19 +429,18 @@ class NerPipeline(Pipeline):
...
@@ -432,19 +429,18 @@ class NerPipeline(Pipeline):
# Manage correct placement of the tensors
# Manage correct placement of the tensors
with
self
.
device_placement
():
with
self
.
device_placement
():
# TODO : Remove this 512 hard-limit
tokens
=
self
.
tokenizer
.
encode_plus
(
tokens
=
self
.
tokenizer
.
encode_plus
(
sentence
,
return_attention_mask
=
False
,
sentence
,
return_attention_mask
=
False
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
max_length
=
512
max_length
=
self
.
tokenizer
.
max_len
)
)
# Forward
# Forward
if
is_torch_available
():
if
is_tf_available
():
entities
=
self
.
model
(
**
tokens
)[
0
][
0
].
numpy
()
else
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
entities
=
self
.
model
(
**
tokens
)[
0
][
0
].
cpu
().
numpy
()
entities
=
self
.
model
(
**
tokens
)[
0
][
0
].
cpu
().
numpy
()
else
:
entities
=
self
.
model
(
tokens
)[
0
][
0
].
numpy
()
# Normalize scores
# Normalize scores
answer
,
token_start
=
[],
1
answer
,
token_start
=
[],
1
...
@@ -484,28 +480,29 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
...
@@ -484,28 +480,29 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
else
:
else
:
kwargs
[
'X'
]
=
list
(
args
)
kwargs
[
'X'
]
=
list
(
args
)
# Generic compatibility with sklearn and Keras
# Generic compatibility with sklearn and Keras
# Batched data
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
inputs
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
if
isinstance
(
inputs
,
dict
):
data
=
[
data
]
inputs
=
[
inputs
]
else
:
# Copy to avoid overriding arguments
inputs
=
[
i
for
i
in
inputs
]
for
i
,
item
in
enumerate
(
data
):
for
i
,
item
in
enumerate
(
inputs
):
if
isinstance
(
item
,
dict
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
inputs
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
continue
el
se
:
el
if
not
isinstance
(
item
,
SquadExample
)
:
raise
ValueError
(
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
)
inputs
=
data
# Tabular input
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
...
@@ -588,12 +585,10 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -588,12 +585,10 @@ class QuestionAnsweringPipeline(Pipeline):
# Manage tensor allocation on correct device
# Manage tensor allocation on correct device
with
self
.
device_placement
():
with
self
.
device_placement
():
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
start
,
end
=
self
.
model
(
fw_args
)
start
,
end
=
self
.
model
(
fw_args
)
start
,
end
=
start
.
numpy
(),
end
.
numpy
()
start
,
end
=
start
.
numpy
(),
end
.
numpy
()
else
:
else
:
import
torch
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# Retrieve the score for the context tokens only (removing question tokens)
# Retrieve the score for the context tokens only (removing question tokens)
fw_args
=
{
k
:
torch
.
tensor
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
fw_args
=
{
k
:
torch
.
tensor
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
...
@@ -812,8 +807,9 @@ def pipeline(task: str, model: Optional = None,
...
@@ -812,8 +807,9 @@ def pipeline(task: str, model: Optional = None,
if
isinstance
(
config
,
str
):
if
isinstance
(
config
,
str
):
config
=
AutoConfig
.
from_pretrained
(
config
)
config
=
AutoConfig
.
from_pretrained
(
config
)
if
allocator
.
__name__
.
startswith
(
'TF'
):
if
isinstance
(
model
,
str
):
model
=
allocator
.
from_pretrained
(
model
,
config
=
config
,
from_pt
=
from_pt
)
if
allocator
.
__name__
.
startswith
(
'TF'
):
else
:
model
=
allocator
.
from_pretrained
(
model
,
config
=
config
,
from_pt
=
from_pt
)
model
=
allocator
.
from_pretrained
(
model
,
config
=
config
,
from_tf
=
from_tf
)
else
:
model
=
allocator
.
from_pretrained
(
model
,
config
=
config
,
from_tf
=
from_tf
)
return
task
(
model
,
tokenizer
,
**
kwargs
)
return
task
(
model
,
tokenizer
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment