pipelines.py 110 KB
Newer Older
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

Morgan Funtowicz's avatar
Morgan Funtowicz committed
16

17
18
import csv
import json
Morgan Funtowicz's avatar
Morgan Funtowicz committed
19
import os
20
import pickle
Aymeric Augustin's avatar
Aymeric Augustin committed
21
import sys
22
import uuid
Morgan Funtowicz's avatar
Morgan Funtowicz committed
23
from abc import ABC, abstractmethod
24
from contextlib import contextmanager
25
from itertools import chain
26
from os.path import abspath, exists
27
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
28
from uuid import UUID
Morgan Funtowicz's avatar
Morgan Funtowicz committed
29
30
31

import numpy as np

32
from .configuration_auto import AutoConfig
33
34
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
Sylvain Gugger's avatar
Sylvain Gugger committed
35
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
36
37
38
39
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
40
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
Lysandre Debut's avatar
Lysandre Debut committed
41
from .utils import logging
Morgan Funtowicz's avatar
Morgan Funtowicz committed
42

Aymeric Augustin's avatar
Aymeric Augustin committed
43

Morgan Funtowicz's avatar
Morgan Funtowicz committed
44
if is_tf_available():
Morgan Funtowicz's avatar
Morgan Funtowicz committed
45
    import tensorflow as tf
46

47
    from .modeling_tf_auto import (
48
49
50
51
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
52
        TFAutoModel,
53
        TFAutoModelForCausalLM,
54
        TFAutoModelForQuestionAnswering,
55
        TFAutoModelForSequenceClassification,
56
        TFAutoModelForTokenClassification,
Julien Chaumond's avatar
Julien Chaumond committed
57
        TFAutoModelWithLMHead,
58
    )
Morgan Funtowicz's avatar
Morgan Funtowicz committed
59
60
61

if is_torch_available():
    import torch
62

63
    from .modeling_auto import (
64
65
66
67
68
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
69
        AutoModel,
70
71
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
72
73
74
75
        AutoModelForQuestionAnswering,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
76
    )
Morgan Funtowicz's avatar
Morgan Funtowicz committed
77

78
79
if TYPE_CHECKING:
    from .modeling_tf_utils import TFPreTrainedModel
80
    from .modeling_utils import PreTrainedModel
81

Morgan Funtowicz's avatar
Morgan Funtowicz committed
82

Lysandre Debut's avatar
Lysandre Debut committed
83
# Module-level logger for this file (e.g. ``Pipeline.save_pretrained`` reports bad paths through it).
logger = logging.get_logger(__name__)
84

85

thomwolf's avatar
thomwolf committed
86
def get_framework(model=None):
    """
    Select framework (TensorFlow or PyTorch) to use.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`):
            Only consulted when both frameworks are installed and :obj:`model` is a model instance (not a string):
            TensorFlow is picked when the instance's class name starts with ``TF``, PyTorch otherwise. A model name
            or no model at all falls back to PyTorch when it is installed.

    Returns:
        :obj:`str`: :obj:`"pt"` or :obj:`"tf"`.

    Raises:
        RuntimeError: If neither TensorFlow nor PyTorch is installed.
    """
    tf_installed = is_tf_available()
    torch_installed = is_torch_available()

    if not tf_installed and not torch_installed:
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )

    has_model_instance = model is not None and not isinstance(model, str)
    if tf_installed and torch_installed and has_model_instance:
        # Both backends are usable: TensorFlow model classes are prefixed with "TF".
        return "tf" if model.__class__.__name__.startswith("TF") else "pt"

    # Only one backend installed, or nothing to infer from: prefer PyTorch when present.
    return "pt" if torch_installed else "tf"

110

111
112
class PipelineException(Exception):
    """
    Raised by a :class:`~transformers.Pipeline` when handling __call__.

    Args:
        task (:obj:`str`): The task of the pipeline.
        model (:obj:`str`): The model used by the pipeline.
        reason (:obj:`str`): The error message to display.
    """

    def __init__(self, task: str, model: str, reason: str):
        # ``reason`` becomes the exception message; task and model are kept as attributes for callers to inspect.
        super().__init__(reason)
        self.task, self.model = task, model


128
129
class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class DefaultArgumentHandler(ArgumentHandler):
    """
    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.

    Normalizes the positional *or* keyword arguments given to a pipeline into a flat list of inputs.
    """

    @staticmethod
    def handle_kwargs(kwargs: Dict) -> List:
        """
        Normalize keyword arguments: their names are ignored and their values are handled exactly like positional
        arguments by :meth:`handle_args`.
        """
        return DefaultArgumentHandler.handle_args(list(kwargs.values()))

    @staticmethod
    def handle_args(args: Sequence[Any]) -> List[str]:
        """
        Normalize positional arguments into a flat list.

        - A single list is returned as-is; any other single value is wrapped in a list.
        - Several values are flattened one level: strings are kept whole, other iterables are expanded.
        - No values give an empty list.

        Raises:
            ValueError: If one of several values is neither a string nor an iterable.
        """
        if len(args) == 1:
            first = args[0]
            if isinstance(first, list):
                return first
            return [first]

        if len(args) > 1:
            flattened = []
            for arg in args:
                if isinstance(arg, str):
                    # Strings are iterable too, but expanding them would split the input into characters.
                    flattened.append(arg)
                elif isinstance(arg, Iterable):
                    flattened.extend(arg)
                else:
                    raise ValueError(
                        "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(arg))
                    )
            return flattened

        return []

    def __call__(self, *args, **kwargs):
        if len(kwargs) > 0 and len(args) > 0:
            raise ValueError("Pipeline cannot handle mixed args and kwargs")

        if len(kwargs) > 0:
            return DefaultArgumentHandler.handle_kwargs(kwargs)
        else:
            return DefaultArgumentHandler.handle_args(args)
Morgan Funtowicz's avatar
Morgan Funtowicz committed
187
188


189
class PipelineDataFormat:
    """
    Base class for all the pipeline supported data formats, both for reading and writing.
    Supported data formats currently include:

    - JSON
    - CSV
    - stdin/stdout (pipe)

    :obj:`PipelineDataFormat` also includes some utilities to work with multi-columns like mapping from datasets
    columns to pipelines keyword arguments through the :obj:`dataset_kwarg_1=dataset_column_1` format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read. Several columns may be given comma-separated.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.

    Raises:
        OSError: If :obj:`output_path` already exists and :obj:`overwrite` is :obj:`False`, or if :obj:`input_path`
        does not exist.
    """

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite: bool = False,
    ):
        self.output_path = output_path
        self.input_path = input_path

        # No column spec means a single, unnamed column.
        columns = [""] if column is None else column.split(",")
        self.is_multi_columns = len(columns) > 1
        if self.is_multi_columns:
            # Each entry becomes a (keyword, dataset_column) pair; a bare name maps onto itself.
            columns = [tuple(spec.split("=")) if "=" in spec else (spec, spec) for spec in columns]
        self.column = columns

        if output_path is not None and not overwrite and exists(abspath(self.output_path)):
            raise OSError("{} already exists on disk".format(self.output_path))

        if input_path is not None and not exists(abspath(self.input_path)):
            raise OSError("{} doesnt exist on disk".format(self.input_path))

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Union[dict, List[dict]]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Pickle the provided data object next to :obj:`output_path`, replacing its extension with ``.pickle``.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.

        Returns:
            :obj:`str`: Path where the data has been saved.
        """
        stem = os.path.splitext(self.output_path)[0]
        binary_path = os.path.extsep.join((stem, "pickle"))

        with open(binary_path, "wb+") as handle:
            pickle.dump(data, handle)

        return binary_path

    @staticmethod
    def from_str(
        format: str,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ) -> "PipelineDataFormat":
        """
        Creates an instance of the right subclass of :class:`~transformers.pipelines.PipelineDataFormat` depending
        on :obj:`format`.

        Args:
            format: (:obj:`str`):
                The format of the desired pipeline. Acceptable values are :obj:`"json"`, :obj:`"csv"` or :obj:`"pipe"`.
            output_path (:obj:`str`, `optional`):
                Where to save the outgoing data.
            input_path (:obj:`str`, `optional`):
                Where to look for the input data.
            column (:obj:`str`, `optional`):
                The column to read.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to overwrite the :obj:`output_path`.

        Returns:
            :class:`~transformers.pipelines.PipelineDataFormat`: The proper data format.

        Raises:
            KeyError: If :obj:`format` is not one of the supported values.
        """
        if format not in ("json", "csv", "pipe"):
            raise KeyError("Unknown reader {} (Available reader are json/csv/pipe)".format(format))

        # The subclasses are only looked up once the format is known to be valid.
        reader_cls = {
            "json": JsonPipelineDataFormat,
            "csv": CsvPipelineDataFormat,
            "pipe": PipedPipelineDataFormat,
        }[format]
        return reader_cls(output_path, input_path, column, overwrite=overwrite)
301
302
303


class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using CSV data format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        """
        Iterate over the rows of :obj:`input_path` (the first row is the header).

        Yields:
            A ``{keyword: value}`` dict when several columns were requested, otherwise the value of the single
            requested column.
        """
        # The csv module requires files opened with newline="" so that it handles line endings itself; otherwise
        # quoted fields containing newlines are not parsed correctly.
        with open(self.input_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        The keys of the first dict are used as the CSV header. Nothing is written (the file is left empty) when
        :obj:`data` is empty.

        Args:
            data (:obj:`List[dict]`): The data to store.
        """
        # newline="" prevents "\r\r\n" row terminators on Windows (the writer already emits "\r\n").
        with open(self.output_path, "w", newline="") as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using JSON file format.

    The input file is expected to hold a JSON array of objects; it is read entirely into memory at construction time.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        # Without an input file (output-only usage) there is simply nothing to iterate over.
        self._entries = []
        if input_path is not None:
            # Read bytes so json detects the UTF-8/16/32 encoding itself (BOM included) instead of relying on the
            # platform's locale encoding.
            with open(input_path, "rb") as f:
                self._entries = json.load(f)

    def __iter__(self):
        """
        Yields:
            A ``{keyword: value}`` dict when several columns were requested, otherwise the value of the single
            requested column of each entry.
        """
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        """
        Save the provided data object in a json file.

        Args:
            data (:obj:`dict`): The data to store.
        """
        with open(self.output_path, "w") as f:
            json.dump(data, f)


Morgan Funtowicz's avatar
Morgan Funtowicz committed
390
391
392
393
394
395
class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the python process.
    For multi columns data, columns should be separated by tabs (``\\t``).

    If several columns are provided, each tab-separated line is yielded as a dictionary ``{column_x: value_x}``.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __iter__(self):
        """
        Iterate over the lines of stdin (line terminators are kept).

        Yields:
            - a ``{keyword: value}`` dict for tab-separated lines when several columns were requested,
            - a tuple of the tab-separated values when they were not,
            - the raw line otherwise.
        """
        for line in sys.stdin:
            # Split for multi-columns
            if "\t" in line:
                values = line.split("\t")

                # ``self.column`` is never empty (it defaults to [""]), so it cannot tell whether (keyword, column)
                # pairs were configured; only the multi-column flag guarantees the pairs the unpacking needs.
                if self.is_multi_columns:
                    # Dictionary to map arguments
                    yield {keyword: value for (keyword, _), value in zip(self.column, values)}
                else:
                    yield tuple(values)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        """
        Print the data.

        Args:
            data (:obj:`dict`): The data to store.
        """
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Pickle :obj:`data` next to :obj:`output_path`, which is mandatory in piped mode.

        Raises:
            KeyError: If no :obj:`output_path` was given.
        """
        if self.output_path is None:
            raise KeyError(
                "When using piped input on pipeline outputting large object requires an output file path. "
                "Please provide such output path through --output argument."
            )

        return super().save_binary(data)

Morgan Funtowicz's avatar
Morgan Funtowicz committed
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453

class _ScikitCompat(ABC):
    """
    Interface layer for the Scikit and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()


Sylvain Gugger's avatar
Sylvain Gugger committed
454
# Shared "Arguments:" section appended (through ``add_end_docstrings``) to the docstring of every pipeline.
PIPELINE_INIT_ARGS = r"""
    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, a non-negative integer will run
            the model on the associated CUDA device id.
        binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Flag indicating if the output of the pipeline should be in a binary format (i.e., pickle) or as raw text.
"""
Morgan Funtowicz's avatar
Morgan Funtowicz committed
482

Lysandre Debut's avatar
Lysandre Debut committed
483

Sylvain Gugger's avatar
Sylvain Gugger committed
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations.
    Pipeline workflow is defined as a sequence of the following operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipelines, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'`),
    output large tensor objects as nested lists. In order to avoid dumping such large structures as textual data we
    provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
    pickle format.
    """

    default_input_names = None

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: Optional[ArgumentHandler] = None,
        device: int = -1,
        binary_output: bool = False,
    ):

        if framework is None:
            framework = get_framework(model)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.modelcard = modelcard
        self.framework = framework
        # TensorFlow keeps the raw ordinal (turned into a device string in ``device_placement``); PyTorch gets a
        # ``torch.device`` where any negative ordinal means CPU.
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
        self.binary_output = binary_output
        self._args_parser = args_parser or DefaultArgumentHandler()

        # PyTorch models are moved to the requested GPU once, at construction time.
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer (and model card, if any).

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to save. It will be created if it doesn't exist. If it points to an
                existing file, an error is logged and nothing is saved.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
        """
        if self.framework == "tf":
            # Any negative ordinal means CPU, exactly like the PyTorch path (see ``__init__``); previously only -1 did,
            # so e.g. -2 produced the invalid device string "/device:GPU:-2".
            with tf.device("/CPU:0" if self.device < 0 else "/device:GPU:{}".format(self.device)):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check if the model class is supported by the pipeline.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.

        Raises:
            PipelineException: If the model's class name is not among the supported ones.
        """
        if not isinstance(supported_models, list):  # Create from a model mapping
            supported_models = [model_class.__name__ for model_class in supported_models.values()]
        if self.model.__class__.__name__ not in supported_models:
            raise PipelineException(
                self.task,
                self.model.base_model_prefix,
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
            )

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments with the pipeline's argument handler, then tokenize them into framework tensors.
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
        )

        return inputs

    def __call__(self, *args, **kwargs):
        inputs = self._parse_and_tokenize(*args, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework specific forward dispatching.

        Args:
            inputs: dict holding all the keyworded arguments for required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array.

        Returns:
            Numpy array (or the native tensor when :obj:`return_tensors` is :obj:`True`) holding the first output
            of the model.
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs.data, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        else:
            return predictions.numpy()
664
665


Sylvain Gugger's avatar
Sylvain Gugger committed
666
# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline that runs the base transformer without any model head and returns its hidden states,
    which can then be used as features in downstream tasks.

    It can currently be loaded from :func:`~transformers.pipeline` with the task identifier
    :obj:`"feature-extraction"`.

    Any model can be used with this pipeline. The full list of models, community-contributed ones included, is
    available on `huggingface.co/models <https://huggingface.co/models>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model the pipeline uses to make predictions. It must inherit from
            :class:`~transformers.PreTrainedModel` for PyTorch or from :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer the pipeline uses to encode data for the model. It must inherit from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use: :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. That framework must be
            installed.

            When omitted, the installed framework is used. When omitted and both frameworks are installed, the
            framework of :obj:`model` is used, or PyTorch if no model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            The object in charge of parsing the supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. -1 runs the model on CPU; a non-negative value runs it on the CUDA
            device with that id.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        task: str = "",
    ):
        # binary_output is always forced on for this pipeline, hence it is not exposed as a constructor argument.
        super().__init__(
            task=task,
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
        )

    def __call__(self, *args, **kwargs):
        """
        Compute the features of the given input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to extract features from.

        Return:
            A nested list of :obj:`float`: The features computed by the model.
        """
        features = super().__call__(*args, **kwargs)
        return features.tolist()
Morgan Funtowicz's avatar
Morgan Funtowicz committed
736
737


Sylvain Gugger's avatar
Sylvain Gugger committed
738
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any :obj:`ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    This language generation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2).
    See the list of available community models on
    `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
    """

    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
    (except for Alexei and Maria) are discovered.
    The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
    remainder of the story. 1883 Western Siberia,
    a young Grigori Rasputin is asked by his father and a group of men to perform magic.
    Rasputin has a vision and denounces one of the men as a horse thief. Although his
    father initially slaps him for making such an accusation, Rasputin watches as the
    man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
    with people, even a bishop, begging for his blessing. """

    ALLOWED_MODELS = [
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "ReformerModelWithLMHead",
        "GPT2LMHeadModel",
        "OpenAIGPTLMHeadModel",
        "CTRLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
        "TFGPT2LMHeadModel",
        "TFOpenAIGPTLMHeadModel",
        "TFCTRLLMHeadModel",
    ]

    # Model classes whose prompts get PADDING_TEXT prepended before generation.
    _PADDED_PROMPT_MODELS = (
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(self.ALLOWED_MODELS)

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            **tokenizer_kwargs,
        )

        return inputs

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`List[int]`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
        """
        text_inputs = self._args_parser(*args)

        # For XLNet and Transformer-XL we add an article to the prompt to give more state to the model.
        # The extra tokens shift the max_length/min_length targets. Adjust them exactly once, here, rather than
        # inside the per-prompt loop; doing it in the loop added the padding length again for every prompt.
        padding_text = ""
        if self.model.__class__.__name__ in self._PADDED_PROMPT_MODELS:
            padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
            padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
            padding_length = padding["input_ids"].shape[-1]
            for length_key in ("max_length", "min_length"):
                if generate_kwargs.get(length_key) is not None:
                    generate_kwargs[length_key] += padding_length

        results = []
        for prompt_text in text_inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                model_text = padding_text + prompt_text if padding_text else prompt_text
                inputs = self._parse_and_tokenize(model_text, padding=False, add_special_tokens=False)

                # set input_ids to None to allow empty prompt
                if inputs["input_ids"].shape[-1] == 0:
                    inputs["input_ids"] = None
                    inputs["attention_mask"] = None

                if self.framework == "pt" and inputs["input_ids"] is not None:
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids is None or input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."

                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL

            result = []
            for generated_sequence in output_sequences:
                if self.framework == "pt" and generated_sequence is not None:
                    generated_sequence = generated_sequence.cpu()
                generated_sequence = generated_sequence.numpy().tolist()
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )

                    # Strip the decoded model input (which includes PADDING_TEXT for XLNet / Transfo-XL) and
                    # re-prepend the user's original prompt.
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    record["generated_text"] = prompt_text + text[prompt_length:]

                result.append(record)
            results += [result]

        if len(results) == 1:
            return results[0]

        return results


Sylvain Gugger's avatar
Sylvain Gugger committed
920
921
922
923
924
925
926
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        return_all_scores (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to return all prediction scores or just the one of the predicted class.
    """,
)
class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using any :obj:`ModelForSequenceClassification`. See the
    `sequence classification examples <../task_summary.html#sequence-classification>`__ for more information.

    This text classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"sentiment-analysis"` (for classifying sequences according to positive or negative
    sentiments).

    The models that this pipeline can use are models that have been fine-tuned on a sequence classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=text-classification>`__.
    """

    def __init__(self, return_all_scores: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
        )

        self.return_all_scores = return_all_scores

    def __call__(self, *args, **kwargs):
        """
        Classify the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) to classify.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
            following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.

            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        outputs = super().__call__(*args, **kwargs)

        # Numerically stable softmax over the label axis. Subtracting the per-row max leaves the result unchanged
        # mathematically, but keeps np.exp from overflowing to inf (which yields NaN scores) on large logits.
        shifted_logits = outputs - outputs.max(-1, keepdims=True)
        exp_logits = np.exp(shifted_logits)
        scores = exp_logits / exp_logits.sum(-1, keepdims=True)

        if self.return_all_scores:
            return [
                [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(item)]
                for item in scores
            ]
        else:
            return [
                {"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
            ]
Morgan Funtowicz's avatar
Morgan Funtowicz committed
980
981


982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot for text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        """
        Normalize ``labels`` to a list.

        A string is treated as comma-separated labels: each entry is stripped, and empty entries (e.g. from a trailing
        comma or ``"a,,b"``) are dropped. Any other value is returned unchanged.
        """
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",") if label.strip()]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        """
        Build the list of ``[sequence, hypothesis]`` pairs fed to the NLI model.

        Args:
            sequences: A single sequence or a list of sequences to classify.
            labels: A comma-separated string of labels or a list of labels.
            hypothesis_template: A ``str.format`` template with a placeholder for the label.

        Returns:
            A list of ``[sequence, hypothesis]`` pairs, ordered by sequence, then by label.

        Raises:
            ValueError: If there is no label or no sequence, or if the template has no placeholder for the label.
        """
        # Parse before validating, so the checks below see real labels (not, e.g., the first character of a
        # comma-separated string).
        labels = self._parse_labels(labels)

        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]

        return [[sequence, hypothesis_template.format(label)] for sequence in sequences for label in labels]


Sylvain Gugger's avatar
Sylvain Gugger committed
1015
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a :obj:`ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the
    candidate label being valid. Any NLI model can be used as long as the first output logit corresponds to
    `contradiction` and the last to `entailment`.

    This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?search=nli>`__.
    """

    def __init__(self, args_parser=None, *args, **kwargs):
        # Build the default handler per instance, not once at function-definition time as a shared default.
        if args_parser is None:
            args_parser = ZeroShotClassificationArgumentHandler()
        super().__init__(*args, args_parser=args_parser, **kwargs)

    @staticmethod
    def _softmax(logits):
        """Softmax over the last axis, shifted by the row max so that np.exp cannot overflow to inf/NaN."""
        shifted = logits - logits.max(-1, keepdims=True)
        exp_logits = np.exp(shifted)
        return exp_logits / exp_logits.sum(-1, keepdims=True)

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation="only_first",
        )

        return inputs

    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
        """
        Classify the sequence(s) given as inputs.

        Args:
            sequences (:obj:`str` or :obj:`List[str]`):
                The sequence(s) to classify, will be truncated if the model input is too large.
            candidate_labels (:obj:`str` or :obj:`List[str]`):
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`):
                The template used to turn each label into an NLI-style hypothesis. This template must include a {}
                or similar syntax for the candidate label to be inserted into the template. For example, the default
                template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed
                into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The
                default template works well in many cases, but it may be worthwhile to experiment with different
                templates depending on the task setting.
            multi_class (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized
                such that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are
                considered independent and probabilities are normalized for each candidate by doing a softmax of
                the entailment score vs. the contradiction score.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **sequence** (:obj:`str`) -- The sequence for which this is the output.
            - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
            - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
        """
        outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
        num_sequences = 1 if isinstance(sequences, str) else len(sequences)
        # Same parsing as the argument handler, so the label count matches the number of generated pairs.
        candidate_labels = self._args_parser._parse_labels(candidate_labels)
        reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

        if len(candidate_labels) == 1:
            multi_class = True

        if not multi_class:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., -1]
            scores = self._softmax(entail_logits)
        else:
            # softmax over the entailment vs. contradiction dim for each label independently
            entail_contr_logits = reshaped_outputs[..., [0, -1]]
            scores = self._softmax(entail_contr_logits)[..., 1]

        result = []
        for iseq in range(num_sequences):
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences if isinstance(sequences, str) else sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq][top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result


Sylvain Gugger's avatar
Sylvain Gugger committed
1117
1118
1119
1120
1121
1122
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        topk (:obj:`int`, defaults to 5): The number of predictions to return.
    """,
)
class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using any :obj:`ModelWithLMHead`. See the
    `masked language modeling examples <../task_summary.html#masked-language-modeling>`__ for more information.

    This mask filling pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"fill-mask"`.

    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
    which includes the bi-directional models in the library.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=lm-head>`__.

    .. note::

        This pipeline only works for inputs with exactly one token masked.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        topk=5,
        task: str = "",
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)

        self.topk = topk

    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
        """Raise a :class:`PipelineException` unless ``masked_index`` holds exactly one element."""
        numel = np.prod(masked_index.shape)
        if numel > 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
            )
        elif numel < 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
            )

    def __call__(self, *args, targets=None, **kwargs):
        """
        Fill the masked token in the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of prompts) with masked tokens.
            targets (:obj:`str` or :obj:`List[str]`, `optional`):
                When passed, the model will return the scores for the passed token or tokens rather than the top k
                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
                be tokenized and the first resulting token will be used (with a warning).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
            following keys:

            - **sequence** (:obj:`str`) -- The corresponding input with the mask token prediction.
            - **score** (:obj:`float`) -- The corresponding probability.
            - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
            - **token_str** (:obj:`str`) -- The predicted token (to replace the masked one).

        Raises:
            ValueError: If ``targets`` is empty, or if a target cannot be tokenized into at least one token.
        """
        inputs = self._parse_and_tokenize(*args, **kwargs)
        outputs = self._forward(inputs, return_tensors=True)

        results = []
        batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

        if targets is not None:
            # Normalize to a list first, so the emptiness check below looks at whole targets.
            if isinstance(targets, str):
                targets = [targets]
            if len(targets) == 0 or len(targets[0]) == 0:
                raise ValueError("At least one target must be provided when passed.")

            targets_proc = []
            for target in targets:
                target_enc = self.tokenizer.tokenize(target)
                if len(target_enc) == 0:
                    # Without this check, target_enc[0] below raises an opaque IndexError.
                    raise ValueError(f"The specified target `{target}` could not be tokenized into any token.")
                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
                    logger.warning(
                        "The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format(
                            target, target_enc[0]
                        )
                    )
                targets_proc.append(target_enc[0])
            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))

        for i in range(batch_size):
            input_ids = inputs["input_ids"][i]
            result = []

            if self.framework == "tf":
                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index)

                logits = outputs[i, masked_index.item(), :]
                probs = tf.nn.softmax(logits)
                if targets is None:
                    topk = tf.math.top_k(probs, k=self.topk)
                    values, predictions = topk.values.numpy(), topk.indices.numpy()
                else:
                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
                    sort_inds = tf.reverse(tf.argsort(values), [0])
                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
                    predictions = target_inds[sort_inds.numpy()]
            else:
                masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index.numpy())

                logits = outputs[i, masked_index.item(), :]
                probs = logits.softmax(dim=0)
                if targets is None:
                    values, predictions = probs.topk(self.topk)
                else:
                    values = probs[..., target_inds]
                    sort_inds = list(reversed(values.argsort(dim=-1)))
                    values = values[..., sort_inds]
                    predictions = target_inds[sort_inds]

            for v, p in zip(values.tolist(), predictions.tolist()):
                # Work on a copy: for PyTorch, ``Tensor.numpy()`` shares memory with the tensor, so writing the
                # prediction in place would silently modify ``inputs``.
                tokens = input_ids.numpy().copy()
                tokens[masked_index] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                result.append(
                    {
                        "sequence": self.tokenizer.decode(tokens),
                        "score": v,
                        "token": p,
                        "token_str": self.tokenizer.convert_ids_to_tokens(p),
                    }
                )

            # Append
            results += [result]

        if len(results) == 1:
            return results[0]
        return results


Sylvain Gugger's avatar
Sylvain Gugger committed
1285
1286
1287
1288
1289
1290
1291
1292
1293
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        ignore_labels (:obj:`List[str]`, `optional`, defaults to :obj:`["O"]`):
            A list of labels to ignore.
        grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
    """,
)
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the
    `named entity recognition examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=token-classification>`__.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        ignore_labels: Optional[List[str]] = None,
        task: str = "",
        grouped_entities: bool = False,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        # Build a fresh list per instance: a list literal as the parameter default would be shared by every
        # pipeline, so mutating `self.ignore_labels` on one instance would leak into all the others.
        self.ignore_labels = ["O"] if ignore_labels is None else ignore_labels
        self.grouped_entities = grouped_entities

    def __call__(self, *args, **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with
            :obj:`grouped_entities=True`) with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word.
            - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
              corresponding token in the sentence.
        """
        inputs = self._args_parser(*args, **kwargs)
        answers = []
        for sentence in inputs:

            # Manage correct placement of the tensors
            with self.device_placement():

                tokens = self.tokenizer(
                    sentence,
                    return_attention_mask=False,
                    return_tensors=self.framework,
                    truncation=True,
                )

                # Forward: keep only the first (and only) sequence of the batch
                if self.framework == "tf":
                    logits = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        logits = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            # Softmax over the label axis. Subtracting the per-token max before np.exp keeps large logits from
            # overflowing to inf (which would yield inf / inf = nan); the resulting probabilities are unchanged.
            shifted_exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
            score = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)

            entities = []
            for idx, label_idx in enumerate(labels_idx):
                label = self.model.config.id2label[label_idx]
                # Filter to labels not in `self.ignore_labels`
                if label in self.ignore_labels:
                    continue
                entities.append(
                    {
                        "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
                        "score": score[idx][label_idx].item(),
                        "entity": label,
                        "index": idx,
                    }
                )

            # Append grouped or ungrouped entities depending on the pipeline configuration
            answers.append(self.group_entities(entities) if self.grouped_entities else entities)

        if len(answers) == 1:
            return answers[0]
        return answers

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline (must be non-empty).

        Returns:
            :obj:`dict`: A dictionary with keys ``entity_group`` (the entity of the first token), ``score`` (the mean
            score of the tokens) and ``word`` (the tokens converted back to a string by the tokenizer).
        """
        # The group is labelled with the entity of its first token
        entity = entities[0]["entity"]
        tokens = [sub_entity["word"] for sub_entity in entities]

        return {
            "entity_group": entity,
            "score": np.mean([sub_entity["score"] for sub_entity in entities]),
            "word": self.tokenizer.convert_tokens_to_string(tokens),
        }

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Two consecutive entities belong to the same group when their labels agree once the "B-"/"I-" style prefix is
        ignored (everything up to the last "-") and their token indices are consecutive.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline, ordered by ``index``.

        Returns:
            :obj:`List[dict]`: One dictionary per group, as built by :meth:`group_sub_entities`.
        """
        entity_groups = []
        current_group = []

        for entity in entities:
            if current_group:
                previous = current_group[-1]
                # The split is meant to account for the "B" and "I" prefixes
                same_type = entity["entity"].split("-")[-1] == previous["entity"].split("-")[-1]
                adjacent = entity["index"] == previous["index"] + 1
                if same_type and adjacent:
                    current_group.append(entity)
                    continue
                # Different or non-adjacent entity: close the running group
                entity_groups.append(self.group_sub_entities(current_group))
            current_group = [entity]

        # Flush the last running group (if any entity was given at all)
        if current_group:
            entity_groups.append(self.group_sub_entities(current_group))

        return entity_groups

Morgan Funtowicz's avatar
Morgan Funtowicz committed
1481

1482
# `NerPipeline` is an alias: both names refer to the same class object.
NerPipeline = TokenClassificationPipeline
1483
1484


1485
1486
1487
class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal :class:`~transformers.SquadExample`.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a :class:`~transformers.SquadExample`
    from the command-line supplied arguments.
    """

    def __call__(self, *args, **kwargs):
        """
        Normalize every supported calling convention into a list of :class:`~transformers.SquadExample`.

        Accepted forms:

        - positional argument(s), handled exactly like ``X``;
        - ``X=`` or ``data=``: a :class:`~transformers.SquadExample`, a dict with ``question`` and ``context`` keys,
          or an iterable of those;
        - ``question=`` and ``context=``: strings or lists of strings. When one side holds a single element it is
          paired with every element of the other side.

        Returns:
            :obj:`List[SquadExample]`: The examples to run through the pipeline.

        Raises:
            KeyError: A dict input lacks the ``question`` or ``context`` key.
            ValueError: An input has an unsupported type, the question and context lists have incompatible lengths,
                or none of the recognized arguments was given.
        """
        # Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
        if args:
            kwargs["X"] = args[0] if len(args) == 1 else list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if "X" in kwargs or "data" in kwargs:
            arg_name = "X" if "X" in kwargs else "data"
            inputs = kwargs[arg_name]

            if isinstance(inputs, (dict, SquadExample)):
                # A single sample: wrap it rather than iterating over it (iterating a SquadExample raises TypeError)
                inputs = [inputs]
            else:
                # Copy to avoid overriding arguments
                inputs = list(inputs)

            for i, item in enumerate(inputs):
                if isinstance(item, dict):
                    if any(k not in item for k in ["question", "context"]):
                        raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

                    inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif not isinstance(item, SquadExample):
                    raise ValueError(
                        "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
                            arg_name
                        )
                    )

        # Tabular input
        elif "question" in kwargs and "context" in kwargs:
            questions = kwargs["question"]
            contexts = kwargs["context"]
            questions = [questions] if isinstance(questions, str) else list(questions)
            contexts = [contexts] if isinstance(contexts, str) else list(contexts)

            # `zip` would silently drop unmatched entries: broadcast a single question/context instead, and
            # reject any other length mismatch.
            if len(questions) != len(contexts):
                if len(contexts) == 1:
                    contexts = contexts * len(questions)
                elif len(questions) == 1:
                    questions = questions * len(contexts)
                else:
                    raise ValueError(
                        "Questions and contexts don't have the same lengths ({} questions vs {} contexts)".format(
                            len(questions), len(contexts)
                        )
                    )

            inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(questions, contexts)]
        else:
            raise ValueError("Unknown arguments {}".format(kwargs))

        return inputs


Sylvain Gugger's avatar
Sylvain Gugger committed
1547
@add_end_docstrings(PIPELINE_INIT_ARGS)
class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using any :obj:`ModelForQuestionAnswering`. See the
    `question answering examples <../task_summary.html#question-answering>`__ for more information.

    This question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=question-answering>`__.
    """

    default_input_names = "question,context"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=QuestionAnsweringArgumentHandler(),
            device=device,
            task=task,
            **kwargs,
        )

        self.check_model_type(
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
        )

    @staticmethod
    def create_sample(
        question: Union[str, List[str]], context: Union[str, List[str]]
    ) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the :class:`~transformers.SquadExample` internally.
        This helper method encapsulates all the logic for converting question(s) and context(s) to
        :class:`~transformers.SquadExample`.

        We currently support extractive question answering.

        Arguments:
            question (:obj:`str` or :obj:`List[str]`): The question(s) asked.
            context (:obj:`str` or :obj:`List[str]`): The context(s) in which we will look for the answer.

        Returns:
            One or a list of :class:`~transformers.SquadExample`: The corresponding
            :class:`~transformers.SquadExample` grouping question and context.
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            args (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`):
                One or several :class:`~transformers.SquadExample` containing the question and context.
            X (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            data (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            question (:obj:`str` or :obj:`List[str]`):
                One or several question(s) (must be used in conjunction with the :obj:`context` argument).
            context (:obj:`str` or :obj:`List[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                :obj:`question` argument).
            topk (:obj:`int`, `optional`, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood).
            doc_stride (:obj:`int`, `optional`, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split in several chunks
                with some overlap. This argument controls the size of that overlap.
            max_answer_len (:obj:`int`, `optional`, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (:obj:`int`, `optional`, defaults to 384):
                The maximum length of the total sentence (context + question) after tokenization. The context will be
                split in several chunks (using :obj:`doc_stride`) if needed.
            max_question_len (:obj:`int`, `optional`, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not we accept impossible as an answer.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **score** (:obj:`float`) -- The probability associated to the answer.
            - **start** (:obj:`int`) -- The start character index of the answer in the original context.
            - **end** (:obj:`int`) -- The end character index of the answer in the original context.
            - **answer** (:obj:`str`) -- The answer to the question.

        Raises:
            ValueError: :obj:`topk` or :obj:`max_answer_len` is smaller than 1.
        """
        # Set defaults values
        kwargs.setdefault("topk", 1)
        kwargs.setdefault("doc_stride", 128)
        kwargs.setdefault("max_answer_len", 15)
        kwargs.setdefault("max_seq_len", 384)
        kwargs.setdefault("max_question_len", 64)
        kwargs.setdefault("handle_impossible_answer", False)

        if kwargs["topk"] < 1:
            raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))

        if kwargs["max_answer_len"] < 1:
            raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

        # Convert inputs to features
        examples = self._args_parser(*args, **kwargs)
        # NOTE(review): with DO_NOT_PAD, the several chunks produced for a context longer than `max_seq_len` may
        # have different lengths; confirm `tf.constant` / `torch.tensor` below accept them for such inputs.
        features_list = [
            squad_convert_examples_to_features(
                examples=[example],
                tokenizer=self.tokenizer,
                max_seq_length=kwargs["max_seq_len"],
                doc_stride=kwargs["doc_stride"],
                max_query_length=kwargs["max_question_len"],
                padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
                is_training=False,
                tqdm_enabled=False,
            )
            for example in examples
        ]
        all_answers = []
        for features, example in zip(features_list, examples):
            model_input_names = self.tokenizer.model_input_names + ["input_ids"]
            fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

            # Manage tensor allocation on correct device
            with self.device_placement():
                if self.framework == "tf":
                    fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                    start, end = self.model(fw_args)[:2]
                    start, end = start.numpy(), end.numpy()
                else:
                    with torch.no_grad():
                        # Retrieve the score for the context tokens only (removing question tokens)
                        fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
                        start, end = self.model(**fw_args)[:2]
                        start, end = start.cpu().numpy(), end.cpu().numpy()

            # Word index of every character of the context; identical for all features of this example
            char_to_word = np.array(example.char_to_word_offset)

            min_null_score = 1000000  # large and positive
            answers = []
            for (feature, start_, end_) in zip(features, start, end):
                # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
                undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

                # Generate mask
                undesired_tokens_mask = undesired_tokens == 0.0

                # Make sure non-context indexes in the tensor cannot contribute to the softmax
                start_ = np.where(undesired_tokens_mask, -10000.0, start_)
                end_ = np.where(undesired_tokens_mask, -10000.0, end_)

                # Normalize logits and spans to retrieve the answer. Subtracting the max before np.exp keeps large
                # logits from overflowing to inf; the resulting probabilities are mathematically unchanged.
                start_ = np.exp(start_ - np.max(start_, axis=-1, keepdims=True))
                start_ = start_ / np.sum(start_, axis=-1, keepdims=True)
                end_ = np.exp(end_ - np.max(end_, axis=-1, keepdims=True))
                end_ = end_ / np.sum(end_, axis=-1, keepdims=True)

                if kwargs["handle_impossible_answer"]:
                    min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

                # Mask CLS
                start_[0] = end_[0] = 0.0

                starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])

                # Convert the answer (tokens) back to the original text
                answers += [
                    {
                        "score": score.item(),
                        "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                        "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
                        "answer": " ".join(
                            example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
                        ),
                    }
                    for s, e, score in zip(starts, ends, scores)
                ]

            if kwargs["handle_impossible_answer"]:
                answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})

            answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
            all_answers += answers

        if len(all_answers) == 1:
            return all_answers[0]
        return all_answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Take the output of any :obj:`ModelForQuestionAnswering` and will generate probabilities for each span to be
        the actual answer.

        In addition, it filters out some unwanted/impossible cases like answer len being greater than
        max_answer_len or answer end position being before the starting position.
        The method supports output the k-best answer through the topk argument.

        Args:
            start (:obj:`np.ndarray`): Individual start probabilities for each token.
            end (:obj:`np.ndarray`): Individual end probabilities for each token.
            topk (:obj:`int`): Indicates how many possible answer span(s) to extract from the model output.
            max_answer_len (:obj:`int`): Maximum size of the answer to extract from the model's output.

        Returns:
            :obj:`Tuple`: The start token indices, end token indices and scores of the selected spans, best first.
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidate with end < start and end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) <= topk:
            # `np.argpartition` requires kth < len(scores_flat), so fully sort when every candidate is requested
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]

    def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to actual word in
        the initial context.

        Args:
            text (:obj:`str`): The actual context to extract the answer from.
            start (:obj:`int`): The answer starting token index.
            end (:obj:`int`): The answer end token index.

        Returns:
            Dictionary like :obj:`{'answer': str, 'start': int, 'end': int}`
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for word in text.split(" "):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {
            "answer": " ".join(words),
            "start": max(0, char_start_idx),
            "end": min(len(text), char_end_idx),
        }
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1834
1835


Sylvain Gugger's avatar
Sylvain Gugger committed
1836
@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Pipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"summarization"`.

    It works with models that have been fine-tuned on a summarization task, which currently includes
    '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`' and '`t5-11b`'.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.

    Usage::

        # use bart in pytorch
        summarizer = pipeline("summarization")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

        # use t5 in tf
        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
    """

    def __init__(self, *args, **kwargs):
        # The task name is forced to "summarization", overriding any value given by the caller.
        kwargs["task"] = "summarization"
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            supported_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
        else:
            supported_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        self.check_model_type(supported_mapping)

    def __call__(
        self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (`str` or :obj:`List[str]`):
                One or several articles (or one list of articles) to summarize. Only the first positional argument
                is used: either a single string or a list of strings.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
              input.
            - **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the summary.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        assert len(documents) > 0, "Please provide a document to summarize"

        is_tf = self.framework == "tf"
        if is_tf and "BartForConditionalGeneration" in self.model.__class__.__name__:
            raise NotImplementedError(
                "Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
            )

        config_prefix = self.model.config.prefix
        prefix = "" if config_prefix is None else config_prefix

        first_document = documents[0]
        if isinstance(first_document, list):
            # Batched input: sequences of different lengths need a padding token
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            model_inputs = ([prefix + document for document in first_document],)
            padding = True
        elif isinstance(first_document, str):
            model_inputs = (prefix + first_document,)
            padding = False
        else:
            raise ValueError(
                " `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
                    first_document
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*model_inputs, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]
            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            min_length = generate_kwargs.get("min_length", self.model.config.min_length)
            if input_length < min_length // 2:
                logger.warning(
                    "Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
                        min_length, input_length
                    )
                )

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length < max_length:
                logger.warning(
                    "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
                        max_length, input_length
                    )
                )

            summaries = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )

            def build_record(summary):
                # One output dictionary per generated sequence, holding only the requested fields
                record = {}
                if return_tensors:
                    record["summary_token_ids"] = summary
                if return_text:
                    record["summary_text"] = self.tokenizer.decode(
                        summary,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                return record

            return [build_record(summary) for summary in summaries]


Sylvain Gugger's avatar
Sylvain Gugger committed
1970
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TranslationPipeline(Pipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=translation>`__.

    Usage::

        en_fr_translator = pipeline("translation_en_to_fr")
        en_fr_translator("How old are you?")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Texts to be translated. Only the first positional argument is used: either a single string or a list
                of strings.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        # Same guard as the summarization pipeline: fail clearly instead of with an IndexError on args[0]
        assert len(args) > 0, "Please provide a text to translate"

        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(args[0], list):
            # Batched input: sequences of different lengths need a padding token
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            args = ([prefix + text for text in args[0]],)
            padding = True

        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                "`args[0]`: {} has the wrong format. It should be either of type `str` or type `list`".format(
                    args[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length > 0.9 * max_length:
                logger.warning(
                    "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
                        input_length, max_length
                    )
                )

            translations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )
            results = []
            for translation in translations:
                record = {}
                if return_tensors:
                    record["translation_token_ids"] = translation
                if return_text:
                    record["translation_text"] = self.tokenizer.decode(
                        translation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


2080
2081
2082
class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    :class:`~transformers.ConversationalPipeline`. The conversation contains a number of utility functions to manage
    the addition of new user input and generated model responses. A conversation needs to contain an unprocessed user
    input before being passed to the :class:`~transformers.ConversationalPipeline`. This user input is either created
    when the class is instantiated, or by calling :obj:`conversation.add_user_input("input")` after a conversation
    turn.

    Arguments:
        text (:obj:`str`, `optional`):
            The initial user input to start the conversation. If not provided, a user input needs to be provided
            manually using the :meth:`~transformers.Conversation.add_user_input` method before the conversation can
            begin.
        conversation_id (:obj:`uuid.UUID`, `optional`):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big lebowski.")

        conversation.add_user_input("Is it good?")
    """

    def __init__(self, text: Optional[str] = None, conversation_id: Optional[UUID] = None):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        self.uuid: UUID = conversation_id
        # User inputs already consumed (moved here by `mark_processed`), in order
        self.past_user_inputs: List[str] = []
        # Responses appended via `append_response` (or by the pipeline), in order
        self.generated_responses: List[str] = []
        # Token ids of the conversation so far, replaced wholesale by `set_history`
        self.history: List[int] = []
        # Pending user input for the next turn, or None when there is none
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This populates the internal :obj:`new_user_input`
        field.

        If an unprocessed input is already pending, it is replaced only when ``overwrite=True``; otherwise the new
        input is ignored. A warning is logged in both cases.

        Args:
            text (:obj:`str`): The user input for the next conversation round.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not existing and unprocessed user input should be overwritten when this function is called.
        """
        if self.new_user_input:
            if overwrite:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
                        self.new_user_input, text
                    )
                )
                self.new_user_input = text
            else:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
                    "Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
                )
        else:
            self.new_user_input = text

    def mark_processed(self):
        """
        Mark the conversation as processed (moves the content of :obj:`new_user_input` to :obj:`past_user_inputs`) and
        empties the :obj:`new_user_input` field.
        """
        if self.new_user_input:
            self.past_user_inputs.append(self.new_user_input)
        self.new_user_input = None

    def append_response(self, response: str):
        """
        Append a response to the list of generated responses.

        Args:
            response (:obj:`str`): The model generated response.
        """
        self.generated_responses.append(response)

    def set_history(self, history: List[int]):
        """
        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
        The history is used by the model to generate responses based on the previous conversation turns.

        Args:
            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
        """
        self.history = history

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Past user inputs and generated responses are paired in order; a pending user input, if any, comes last.

        Return:
            :obj:`str`:

            Example:
            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
            user >> Going to the movies tonight - any suggestions?
            bot >> The Big Lebowski
        """
        lines = ["Conversation id: {} \n".format(self.uuid)]
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            lines.append("user >> {} \n".format(user_input))
            lines.append("bot >> {} \n".format(generated_response))
        if self.new_user_input is not None:
            lines.append("user >> {} \n".format(self.new_user_input))
        return "".join(lines)


Sylvain Gugger's avatar
Sylvain Gugger committed
2195
2196
2197
2198
2199
2200
2201
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        min_length_for_response (:obj:`int`, `optional`, defaults to 32):
            The minimum length (in number of tokens) for a response.
    """,
)
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    This conversational pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: `'microsoft/DialoGPT-small'`, `'microsoft/DialoGPT-medium'`, `'microsoft/DialoGPT-large'`.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        # NOTE(review): `min_length_for_response` is the first positional parameter, so positional arguments meant for
        # the base pipeline would bind to it; pass everything by keyword (as the `pipeline()` factory does).
        super().__init__(*args, **kwargs)
        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
        # Fall back to the EOS token for padding when the tokenizer has no dedicated padding token
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.eos_token_id
        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for. Each of them must contain a new (unprocessed) user input.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
            updated generated responses for those containing a new user input. A single
            :class:`~transformers.Conversation` is returned whenever exactly one conversation was processed, even if it
            was passed inside a list.

        Raises:
            ValueError: If :obj:`conversations` is neither a :class:`~transformers.Conversation` nor a list, or if a
            conversation has no new user input to process.
        """

        # Input validation: a single conversation goes through the same checks as a batch
        if isinstance(conversations, Conversation):
            conversations = [conversations]
        elif not isinstance(conversations, list):
            raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

        for conversation in conversations:
            assert isinstance(
                conversation, Conversation
            ), "DialoguePipeline expects a Conversation or list of Conversations as an input"
            if conversation.new_user_input is None:
                raise ValueError(
                    "Conversation with UUID {} does not contain new user input to process. "
                    "Add user inputs with the conversation's `add_user_input` method".format(conversation.uuid)
                )
        assert (
            self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
        ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"

        with self.device_placement():

            inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
            histories = [conversation.history for conversation in conversations]
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            inputs = self._concat_inputs_history(inputs, histories, max_length)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            if input_length > 0.9 * max_length:
                logger.warning(
                    "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                    "You might consider trimming the early phase of the conversation".format(input_length, max_length)
                )
            generated_responses = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )

            cleaned_history = self._clean_padding_history(generated_responses)
            output = []
            for conversation_index, conversation in enumerate(conversations):
                # Every generated row starts with the padded prompt of length `input_length`, so the response is
                # sliced from the raw row. The cleaned history of a shorter (padded) conversation is shorter than
                # `input_length`, so slicing it would drop the first response tokens.
                response_ids = self._sequence_to_list(generated_responses[conversation_index][input_length:])
                conversation.mark_processed()
                conversation.generated_responses.append(
                    self.tokenizer.decode(
                        response_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                )
                conversation.set_history(cleaned_history[conversation_index])
                output.append(conversation)
            if len(output) == 1:
                return output[0]
            else:
                return output

    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for input in inputs:
            input.append(self.tokenizer.eos_token_id)
        return inputs

    def _sequence_to_list(self, sequence) -> List[int]:
        """
        Converts a 1-D tensor of token ids (PyTorch or TensorFlow, depending on the pipeline framework) into a list of
        Python ints.
        """
        if self.framework == "pt":
            return [token.item() for token in sequence]
        return [int(token.numpy()) for token in sequence]

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
        an input:
            - at the end of the concatenated history and new user input, so that all input to the model have the same
                length
            - at the end of the generated response, as some responses will be longer than others
        This method collapses each run of consecutive padding tokens into a single one, so that the history for each
        conversation is not impacted by the batching process.
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in self._sequence_to_list(sequence):
                if token == self.pad_token_id:
                    if is_previous_pad:
                        continue
                    is_previous_pad = True
                else:
                    is_previous_pad = False
                sequence_tokens.append(token)

            outputs.append(sequence_tokens)
        return outputs

    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context.

        When a conversation would leave fewer than :obj:`min_length_for_response` tokens for the response, its oldest
        turns (each terminated by an EOS token) are dropped. The last turn, which holds the new user input, is always
        kept. The batch is then right-padded with :obj:`pad_token_id`, and the attention mask ignores that padding.
        """
        max_input_length = max_length - self.min_length_for_response
        outputs = []
        for new_input, history in zip(inputs, histories):
            if history is not None:
                new_input = history + new_input
            while len(new_input) > max_input_length:
                try:
                    first_eos_index = new_input.index(self.tokenizer.eos_token_id)
                except ValueError:
                    # No turn separator left: nothing can be dropped cleanly
                    break
                if first_eos_index >= len(new_input) - 1:
                    # Only the last turn (the new user input) remains
                    break
                new_input = new_input[first_eos_index + 1 :]
            outputs.append(new_input)
        max_len = max([len(item) for item in outputs])
        input_ids = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
        # One mask row per conversation: attend to every real token, ignore the right padding added above
        attention_mask = [[1] * len(output) + [0] * (max_len - len(output)) for output in outputs]
        return BatchEncoding(
            {"input_ids": input_ids, "attention_mask": attention_mask},
            tensor_type=self.framework,
        )


2394
# Register all the supported tasks here
Morgan Funtowicz's avatar
Morgan Funtowicz committed
2395
# Registry of pipeline tasks. Each entry maps a task identifier to:
#   - "impl": the pipeline class implementing the task,
#   - "tf" / "pt": the auto-model class for each framework (None when that framework is not installed),
#   - "default": default checkpoint names per framework under "model" (and optionally "config" / "tokenizer").
SUPPORTED_TASKS = {
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": TFAutoModel if is_tf_available() else None,
        "pt": AutoModel if is_torch_available() else None,
        "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
    },
    "sentiment-analysis": {
        "impl": TextClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
            },
        },
    },
    "ner": {
        "impl": TokenClassificationPipeline,
        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
        "pt": AutoModelForTokenClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
            },
        },
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
        "default": {
            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForMaskedLM if is_torch_available() else None,
        "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
    },
    # English -> French / German / Romanian share the same classes and default T5 checkpoint; a fresh entry dict
    # is built for each language, inserted in this order.
    **{
        "translation_en_to_{}".format(target_language): {
            "impl": TranslationPipeline,
            "tf": TFAutoModelWithLMHead if is_tf_available() else None,
            "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
            "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
        }
        for target_language in ("fr", "de", "ro")
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
        },
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
    },
}


2487
2488
2489
2490
2491
def pipeline(
    task: str,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    framework: Optional[str] = None,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a :class:`~transformers.Pipeline`.

    Pipelines are made of:

        - A :doc:`tokenizer <tokenizer>` in charge of mapping raw textual input to tokens.
        - A :doc:`model <model>` to make predictions from the inputs.
        - Some (optional) post processing for enhancing model's output.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
            - :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
            - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
            - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
            - :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
            - :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
            - :obj:`"zero-shot-classification"`: will return a
              :class:`~transformers.ZeroShotClassificationPipeline`.
            - :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)
            or :class:`~transformers.TFPreTrainedModel` (for TensorFlow).

            If not provided, the default for the :obj:`task` will be loaded.
        config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from
            :class:`~transformers.PretrainedConfig`.

            If not provided, the default for the :obj:`task` will be loaded.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier, a tuple :obj:`(identifier, kwargs_dict)` whose dict is forwarded to
            :obj:`AutoTokenizer.from_pretrained`, or an actual pretrained tokenizer inheriting from
            :class:`~transformers.PreTrainedTokenizer`.

            If not provided, it is inferred from :obj:`model` (or, failing that, :obj:`config`) when that argument
            is given as a string identifier.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        kwargs:
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        :class:`~transformers.Pipeline`: A suitable pipeline for the task.

    Raises:
        KeyError: If :obj:`task` is not one of the supported tasks.
        Exception: If no :obj:`tokenizer` is given and it cannot be inferred from a string :obj:`model` or
            :obj:`config`.

    Examples::

        from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

        # Sentiment analysis pipeline
        pipeline('sentiment-analysis')

        # Question answering pipeline, specifying the checkpoint identifier
        pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')

        # Named entity recognition pipeline, passing in a specific model and tokenizer
        model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        pipeline('ner', model=model, tokenizer=tokenizer)
    """
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    framework = framework or get_framework(model)

    targeted_task = SUPPORTED_TASKS[task]
    # NOTE(review): the registry stores None for a framework that is not installed, so model_class may be None here.
    task_class, model_class = targeted_task["impl"], targeted_task[framework]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        model = targeted_task["default"]["model"][framework]

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        elif isinstance(config, str):
            tokenizer = config
        else:
            # Impossible to guess what is the right tokenizer here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    modelcard = None
    # Try to infer modelcard from model or config name (if provided as str)
    if isinstance(model, str):
        modelcard = model
    elif isinstance(config, str):
        modelcard = config

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, (str, tuple)):
        if isinstance(tokenizer, tuple):
            # For tuple we have (tokenizer name, {kwargs})
            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
        else:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
        modelcard = ModelCard.from_pretrained(modelcard)

    # Instantiate model if needed
    if isinstance(model, str):
        # Handle transparent TF/PT model conversion: infer the checkpoint's origin framework from its file
        # extension and ask the selected framework's loader to convert it.
        model_kwargs = {}
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )
        model = model_class.from_pretrained(model, config=config, **model_kwargs)

    return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)