Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
baca8fa8
Unverified
Commit
baca8fa8
authored
Apr 16, 2020
by
Patrick von Platen
Committed by
GitHub
Apr 16, 2020
Browse files
clean pipelines (#3795)
parent
38f7461d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
52 deletions
+21
-52
src/transformers/pipelines.py
src/transformers/pipelines.py
+3
-31
tests/test_pipelines.py
tests/test_pipelines.py
+18
-21
No files found.
src/transformers/pipelines.py
View file @
baca8fa8
...
...
@@ -23,17 +23,12 @@ import sys
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
os.path
import
abspath
,
exists
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AutoConfig
from
.configuration_bart
import
BartConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_roberta
import
RobertaConfig
from
.configuration_t5
import
T5Config
from
.configuration_utils
import
PretrainedConfig
from
.configuration_xlm
import
XLMConfig
from
.data
import
SquadExample
,
squad_convert_examples_to_features
from
.file_utils
import
is_tf_available
,
is_torch_available
from
.modelcard
import
ModelCard
...
...
@@ -423,27 +418,6 @@ class Pipeline(_ScikitCompat):
"""
return
{
name
:
tensor
.
to
(
self
.
device
)
for
name
,
tensor
in
inputs
.
items
()}
def
inputs_for_model
(
self
,
features
:
Union
[
dict
,
List
[
dict
]])
->
Dict
:
"""
Generates the input dictionary with model-specific parameters.
Returns:
dict holding all the required parameters for model's forward
"""
args
=
[
"input_ids"
,
"attention_mask"
]
if
not
isinstance
(
self
.
model
.
config
,
(
DistilBertConfig
,
XLMConfig
,
RobertaConfig
,
BartConfig
,
T5Config
)):
args
+=
[
"token_type_ids"
]
# PR #1548 (CLI) There is an issue with attention_mask
# if 'xlnet' in model_type or 'xlm' in model_type:
# args += ['cls_index', 'p_mask']
if
isinstance
(
features
,
dict
):
return
{
k
:
features
[
k
]
for
k
in
args
}
else
:
return
{
k
:
[
feature
[
k
]
for
feature
in
features
]
for
k
in
args
}
def
_parse_and_tokenize
(
self
,
*
texts
,
pad_to_max_length
=
False
,
**
kwargs
):
"""
Parse arguments and tokenize
...
...
@@ -458,9 +432,6 @@ class Pipeline(_ScikitCompat):
pad_to_max_length
=
pad_to_max_length
,
)
# Filter out features not available on specific models
# inputs = self.inputs_for_model(inputs)
return
inputs
def
__call__
(
self
,
*
texts
,
**
kwargs
):
...
...
@@ -995,7 +966,8 @@ class QuestionAnsweringPipeline(Pipeline):
]
all_answers
=
[]
for
features
,
example
in
zip
(
features_list
,
examples
):
fw_args
=
self
.
inputs_for_model
([
f
.
__dict__
for
f
in
features
])
model_input_names
=
self
.
tokenizer
.
model_input_names
+
[
"input_ids"
]
fw_args
=
{
k
:
[
feature
.
__dict__
[
k
]
for
feature
in
features
]
for
k
in
model_input_names
}
# Manage tensor allocation on correct device
with
self
.
device_placement
():
...
...
tests/test_pipelines.py
View file @
baca8fa8
...
...
@@ -2,26 +2,19 @@ import unittest
from
typing
import
Iterable
,
List
,
Optional
from
transformers
import
pipeline
from
transformers.pipelines
import
(
FeatureExtractionPipeline
,
FillMaskPipeline
,
NerPipeline
,
Pipeline
,
QuestionAnsweringPipeline
,
TextClassificationPipeline
,
)
from
transformers.pipelines
import
Pipeline
from
.utils
import
require_tf
,
require_torch
,
slow
QA_FINETUNED_MODELS
=
[
((
"bert-base-uncased"
,
{
"use_fast"
:
False
}),
"bert-large-uncased-whole-word-masking-finetuned-squad"
,
None
),
((
"bert-base-cased"
,
{
"use_fast"
:
False
}),
"distilbert-base-cased-distilled-squad"
,
None
),
((
"
distil
bert-base-cased
-distilled-squad
"
,
{
"use_fast"
:
False
}),
"distilbert-base-cased-distilled-squad"
,
None
),
]
TF_QA_FINETUNED_MODELS
=
[
((
"bert-base-uncased"
,
{
"use_fast"
:
False
}),
"bert-large-uncased-whole-word-masking-finetuned-squad"
,
None
),
((
"bert-base-cased"
,
{
"use_fast"
:
False
}),
"distilbert-base-cased-distilled-squad"
,
None
),
((
"
distil
bert-base-cased
-distilled-squad
"
,
{
"use_fast"
:
False
}),
"distilbert-base-cased-distilled-squad"
,
None
),
]
TF_NER_FINETUNED_MODELS
=
{
...
...
@@ -369,25 +362,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
class
PipelineCommonTests
(
unittest
.
TestCase
):
pipelines
=
(
NerPipeline
,
FeatureExtractionPipeline
,
QuestionAnsweringPipeline
,
FillMaskPipeline
,
TextClassificationPipeline
,
"ner"
,
"feature-extraction"
,
"question-answering"
,
"fill-mask"
,
"summarization"
,
"sentiment-analysis"
,
"translation_en_to_fr"
,
"translation_en_to_de"
,
"translation_en_to_ro"
,
)
@
slow
@
require_tf
def
test_tf_defaults
(
self
):
# Test that pipelines can be correctly loaded without any argument
for
default_pipeline
in
self
.
pipelines
:
with
self
.
subTest
(
msg
=
"Testing Torch defaults with PyTorch and {}"
.
format
(
default_pipeline
.
task
)):
default_
pipeline
(
framework
=
"tf"
)
for
task
in
self
.
pipelines
:
with
self
.
subTest
(
msg
=
"Testing Torch defaults with PyTorch and {}"
.
format
(
task
)):
pipeline
(
task
,
framework
=
"tf"
)
@
slow
@
require_torch
def
test_pt_defaults
(
self
):
# Test that pipelines can be correctly loaded without any argument
for
default_pipeline
in
self
.
pipelines
:
with
self
.
subTest
(
msg
=
"Testing Torch defaults with PyTorch and {}"
.
format
(
default_pipeline
.
task
)):
default_
pipeline
(
framework
=
"pt"
)
for
task
in
self
.
pipelines
:
with
self
.
subTest
(
msg
=
"Testing Torch defaults with PyTorch and {}"
.
format
(
task
)):
pipeline
(
task
,
framework
=
"pt"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment