Unverified Commit baca8fa8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

clean pipelines (#3795)

parent 38f7461d
......@@ -23,17 +23,12 @@ import sys
from abc import ABC, abstractmethod
from contextlib import contextmanager
from os.path import abspath, exists
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_bart import BartConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_roberta import RobertaConfig
from .configuration_t5 import T5Config
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLMConfig
from .data import SquadExample, squad_convert_examples_to_features
from .file_utils import is_tf_available, is_torch_available
from .modelcard import ModelCard
......@@ -423,27 +418,6 @@ class Pipeline(_ScikitCompat):
"""
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
    """
    Build the model-specific input dictionary from tokenized features.

    Args:
        features: a single feature dict, or a list of feature dicts
            (one per example in a batch).

    Returns:
        dict mapping each required forward-pass argument name to its value
        (single dict input) or to a list of per-feature values (list input).
    """
    required = ["input_ids", "attention_mask"]
    # These architectures take no token_type_ids in their forward pass.
    no_token_type_ids = (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)
    if not isinstance(self.model.config, no_token_type_ids):
        required.append("token_type_ids")

    # PR #1548 (CLI) There is an issue with attention_mask
    # if 'xlnet' in model_type or 'xlm' in model_type:
    #     args += ['cls_index', 'p_mask']

    if isinstance(features, dict):
        return {name: features[name] for name in required}
    return {name: [feature[name] for feature in features] for name in required}
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
"""
Parse arguments and tokenize
......@@ -458,9 +432,6 @@ class Pipeline(_ScikitCompat):
pad_to_max_length=pad_to_max_length,
)
# Filter out features not available on specific models
# inputs = self.inputs_for_model(inputs)
return inputs
def __call__(self, *texts, **kwargs):
......@@ -995,7 +966,8 @@ class QuestionAnsweringPipeline(Pipeline):
]
all_answers = []
for features, example in zip(features_list, examples):
fw_args = self.inputs_for_model([f.__dict__ for f in features])
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
# Manage tensor allocation on correct device
with self.device_placement():
......
......@@ -2,26 +2,19 @@ import unittest
from typing import Iterable, List, Optional
from transformers import pipeline
from transformers.pipelines import (
FeatureExtractionPipeline,
FillMaskPipeline,
NerPipeline,
Pipeline,
QuestionAnsweringPipeline,
TextClassificationPipeline,
)
from transformers.pipelines import Pipeline
from .utils import require_tf, require_torch, slow
def _qa_model_fixtures():
    """Return a fresh (tokenizer spec, checkpoint, expected output) fixture list.

    Built by a factory so the PT and TF lists hold equal values without
    sharing any mutable tokenizer-kwargs dicts.
    """
    return [
        (("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
        (("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
        (("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
    ]


QA_FINETUNED_MODELS = _qa_model_fixtures()

TF_QA_FINETUNED_MODELS = _qa_model_fixtures()
TF_NER_FINETUNED_MODELS = {
......@@ -369,25 +362,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
class PipelineCommonTests(unittest.TestCase):
    """Smoke-tests that every supported task loads via the `pipeline()` factory
    with no explicit model/tokenizer arguments.

    NOTE(review): the pasted block mixed pre-commit pipeline classes and the
    two loop variants per test; reconciled here to the task-string version —
    the old loop would have called a string (`default_pipeline(framework=...)`)
    and raised TypeError.
    """

    # Task identifiers understood by `pipeline(...)`; each resolves to a
    # default model + tokenizer pair.
    pipelines = (
        "ner",
        "feature-extraction",
        "question-answering",
        "fill-mask",
        "summarization",
        "sentiment-analysis",
        "translation_en_to_fr",
        "translation_en_to_de",
        "translation_en_to_ro",
    )

    @slow
    @require_tf
    def test_tf_defaults(self):
        # Test that pipelines can be correctly loaded without any argument
        for task in self.pipelines:
            # Message fixed: this is the TensorFlow path, not PyTorch.
            with self.subTest(msg="Testing TF defaults with TensorFlow and {}".format(task)):
                pipeline(task, framework="tf")

    @slow
    @require_torch
    def test_pt_defaults(self):
        # Test that pipelines can be correctly loaded without any argument
        for task in self.pipelines:
            with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
                pipeline(task, framework="pt")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.