Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3492a6ec
Commit
3492a6ec
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Addressing Thom's comments.
parent
81a911cc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
37 deletions
+33
-37
transformers/pipelines.py
transformers/pipelines.py
+33
-37
No files found.
transformers/pipelines.py
View file @
3492a6ec
...
@@ -30,6 +30,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, Pretrai
...
@@ -30,6 +30,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, Pretrai
SquadExample
,
squad_convert_examples_to_features
,
is_tf_available
,
is_torch_available
,
logger
SquadExample
,
squad_convert_examples_to_features
,
is_tf_available
,
is_torch_available
,
logger
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers
import
TFAutoModel
,
TFAutoModelForSequenceClassification
,
\
from
transformers
import
TFAutoModel
,
TFAutoModelForSequenceClassification
,
\
TFAutoModelForQuestionAnswering
,
TFAutoModelForTokenClassification
TFAutoModelForQuestionAnswering
,
TFAutoModelForTokenClassification
...
@@ -79,9 +80,9 @@ class PipelineDataFormat:
...
@@ -79,9 +80,9 @@ class PipelineDataFormat:
"""
"""
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
def
__init__
(
self
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
self
.
output
=
output
self
.
output
=
output
self
.
path
=
path
self
.
path
=
input
self
.
column
=
column
.
split
(
','
)
if
column
else
[
''
]
self
.
column
=
column
.
split
(
','
)
if
column
else
[
''
]
self
.
is_multi_columns
=
len
(
self
.
column
)
>
1
self
.
is_multi_columns
=
len
(
self
.
column
)
>
1
...
@@ -92,7 +93,7 @@ class PipelineDataFormat:
...
@@ -92,7 +93,7 @@ class PipelineDataFormat:
if
exists
(
abspath
(
self
.
output
)):
if
exists
(
abspath
(
self
.
output
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
if
path
is
not
None
:
if
input
is
not
None
:
if
not
exists
(
abspath
(
self
.
path
)):
if
not
exists
(
abspath
(
self
.
path
)):
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
...
@@ -136,8 +137,8 @@ class PipelineDataFormat:
...
@@ -136,8 +137,8 @@ class PipelineDataFormat:
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
super
().
__init__
(
output
,
path
,
column
)
super
().
__init__
(
output
,
input
,
column
)
def
__iter__
(
self
):
def
__iter__
(
self
):
with
open
(
self
.
path
,
'r'
)
as
f
:
with
open
(
self
.
path
,
'r'
)
as
f
:
...
@@ -157,10 +158,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
...
@@ -157,10 +158,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
super
().
__init__
(
output
,
path
,
column
)
super
().
__init__
(
output
,
input
,
column
)
with
open
(
path
,
'r'
)
as
f
:
with
open
(
input
,
'r'
)
as
f
:
self
.
_entries
=
json
.
load
(
f
)
self
.
_entries
=
json
.
load
(
f
)
def
__iter__
(
self
):
def
__iter__
(
self
):
...
@@ -321,11 +322,9 @@ class Pipeline(_ScikitCompat):
...
@@ -321,11 +322,9 @@ class Pipeline(_ScikitCompat):
Context manager
Context manager
"""
"""
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
with
tf
.
device
(
'/CPU:0'
if
self
.
device
==
-
1
else
'/device:GPU:{}'
.
format
(
self
.
device
)):
with
tf
.
device
(
'/CPU:0'
if
self
.
device
==
-
1
else
'/device:GPU:{}'
.
format
(
self
.
device
)):
yield
yield
else
:
else
:
import
torch
if
self
.
device
>=
0
:
if
self
.
device
>=
0
:
torch
.
cuda
.
set_device
(
self
.
device
)
torch
.
cuda
.
set_device
(
self
.
device
)
...
@@ -358,11 +357,10 @@ class Pipeline(_ScikitCompat):
...
@@ -358,11 +357,10 @@ class Pipeline(_ScikitCompat):
# Encode for forward
# Encode for forward
with
self
.
device_placement
():
with
self
.
device_placement
():
# TODO : Remove this 512 hard-limit
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
,
add_special_tokens
=
True
,
inputs
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
max_length
=
512
max_length
=
self
.
tokenizer
.
max_len
)
)
# Filter out features not available on specific models
# Filter out features not available on specific models
...
@@ -379,11 +377,10 @@ class Pipeline(_ScikitCompat):
...
@@ -379,11 +377,10 @@ class Pipeline(_ScikitCompat):
"""
"""
if
is_tf_available
():
if
is_tf_available
():
# TODO trace model
# TODO trace model
predictions
=
self
.
model
(
inputs
)[
0
]
predictions
=
self
.
model
(
inputs
,
training
=
False
)[
0
]
else
:
else
:
import
torch
with
torch
.
no_grad
():
with
torch
.
no_grad
():
predictions
=
self
.
model
(
**
inputs
)[
0
]
predictions
=
self
.
model
(
**
inputs
)
.
cpu
()
[
0
]
return
predictions
.
numpy
()
return
predictions
.
numpy
()
...
@@ -432,19 +429,18 @@ class NerPipeline(Pipeline):
...
@@ -432,19 +429,18 @@ class NerPipeline(Pipeline):
# Manage correct placement of the tensors
# Manage correct placement of the tensors
with
self
.
device_placement
():
with
self
.
device_placement
():
# TODO : Remove this 512 hard-limit
tokens
=
self
.
tokenizer
.
encode_plus
(
tokens
=
self
.
tokenizer
.
encode_plus
(
sentence
,
return_attention_mask
=
False
,
sentence
,
return_attention_mask
=
False
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
,
max_length
=
512
max_length
=
self
.
tokenizer
.
max_len
)
)
# Forward
# Forward
if
is_torch_available
():
if
is_tf_available
():
entities
=
self
.
model
(
**
tokens
)[
0
][
0
].
numpy
()
else
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
entities
=
self
.
model
(
**
tokens
)[
0
][
0
].
cpu
().
numpy
()
entities
=
self
.
model
(
**
tokens
)[
0
][
0
].
cpu
().
numpy
()
else
:
entities
=
self
.
model
(
tokens
)[
0
][
0
].
numpy
()
# Normalize scores
# Normalize scores
answer
,
token_start
=
[],
1
answer
,
token_start
=
[],
1
...
@@ -487,25 +483,26 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
...
@@ -487,25 +483,26 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
# Generic compatibility with sklearn and Keras
# Generic compatibility with sklearn and Keras
# Batched data
# Batched data
if
'X'
in
kwargs
or
'data'
in
kwargs
:
if
'X'
in
kwargs
or
'data'
in
kwargs
:
data
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
inputs
=
kwargs
[
'X'
]
if
'X'
in
kwargs
else
kwargs
[
'data'
]
if
not
isinstance
(
data
,
list
):
if
isinstance
(
inputs
,
dict
):
data
=
[
data
]
inputs
=
[
inputs
]
else
:
# Copy to avoid overriding arguments
inputs
=
[
i
for
i
in
inputs
]
for
i
,
item
in
enumerate
(
data
):
for
i
,
item
in
enumerate
(
inputs
):
if
isinstance
(
item
,
dict
):
if
isinstance
(
item
,
dict
):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
if
any
(
k
not
in
item
for
k
in
[
'question'
,
'context'
]):
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
raise
KeyError
(
'You need to provide a dictionary with keys {question:..., context:...}'
)
data
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
elif
isinstance
(
item
,
SquadExample
):
inputs
[
i
]
=
QuestionAnsweringPipeline
.
create_sample
(
**
item
)
continue
el
se
:
el
if
not
isinstance
(
item
,
SquadExample
)
:
raise
ValueError
(
raise
ValueError
(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
.
format
(
'X'
if
'X'
in
kwargs
else
'data'
)
)
)
inputs
=
data
# Tabular input
# Tabular input
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
elif
'question'
in
kwargs
and
'context'
in
kwargs
:
...
@@ -588,12 +585,10 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -588,12 +585,10 @@ class QuestionAnsweringPipeline(Pipeline):
# Manage tensor allocation on correct device
# Manage tensor allocation on correct device
with
self
.
device_placement
():
with
self
.
device_placement
():
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
start
,
end
=
self
.
model
(
fw_args
)
start
,
end
=
self
.
model
(
fw_args
)
start
,
end
=
start
.
numpy
(),
end
.
numpy
()
start
,
end
=
start
.
numpy
(),
end
.
numpy
()
else
:
else
:
import
torch
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# Retrieve the score for the context tokens only (removing question tokens)
# Retrieve the score for the context tokens only (removing question tokens)
fw_args
=
{
k
:
torch
.
tensor
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
fw_args
=
{
k
:
torch
.
tensor
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
...
@@ -812,6 +807,7 @@ def pipeline(task: str, model: Optional = None,
...
@@ -812,6 +807,7 @@ def pipeline(task: str, model: Optional = None,
if
isinstance
(
config
,
str
):
if
isinstance
(
config
,
str
):
config
=
AutoConfig
.
from_pretrained
(
config
)
config
=
AutoConfig
.
from_pretrained
(
config
)
if
isinstance
(
model
,
str
):
if
allocator
.
__name__
.
startswith
(
'TF'
):
if
allocator
.
__name__
.
startswith
(
'TF'
):
model
=
allocator
.
from_pretrained
(
model
,
config
=
config
,
from_pt
=
from_pt
)
model
=
allocator
.
from_pretrained
(
model
,
config
=
config
,
from_pt
=
from_pt
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment