# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import json
import os
import pickle
import sys
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from uuid import UUID

import numpy as np

from .configuration_auto import AutoConfig
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
from .utils import logging


if is_tf_available():
    import tensorflow as tf

    from .modeling_tf_auto import (
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTokenClassification,
    )

if is_torch_available():
    import torch

    from .modeling_auto import (
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
        AutoModelForQuestionAnswering,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
    )

if TYPE_CHECKING:
    from .modeling_tf_utils import TFPreTrainedModel
    from .modeling_utils import PreTrainedModel


logger = logging.get_logger(__name__)


def get_framework(model):
    """
    Select framework (TensorFlow or PyTorch) to use.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
            the model name). If no specific model is provided, defaults to using PyTorch.
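
    Example (a minimal sketch; the checkpoint name is illustrative)::

        # Resolves to "pt" when the name loads as a PyTorch model
        framework = get_framework("bert-base-cased")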
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        if is_torch_available() and not is_tf_available():
            model = AutoModel.from_pretrained(model)
        elif is_tf_available() and not is_torch_available():
            model = TFAutoModel.from_pretrained(model)
        else:
            try:
                model = AutoModel.from_pretrained(model)
            except OSError:
                model = TFAutoModel.from_pretrained(model)

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework


def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
    """
    Select a default model to use for a given task. Defaults to PyTorch if ambiguous.

    Args:
        targeted_task (:obj:`Dict`):
           Dictionary representing the given task, that should contain default models

        framework (:obj:`str`, None)
           "pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.

        task_options (:obj:`Any`, None)
           Any further value required by the task to get fully specified, for instance ``(SRC, TGT)`` languages
           for the translation task.

    Returns:

        :obj:`str` The model string representing the default model for this pipeline
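
    Example (a minimal sketch; checkpoint names are illustrative and the dict mirrors the "default"/"model" layout)::

        targeted_task = {"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}}}
        model_name = get_default_model(targeted_task, framework=None, task_options=None)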
    """
    if is_torch_available() and not is_tf_available():
        framework = "pt"
    elif is_tf_available() and not is_torch_available():
        framework = "tf"

    defaults = targeted_task["default"]
    if task_options:
        if task_options not in defaults:
            raise ValueError("The task does not provide any default models for options {}".format(task_options))
        default_models = defaults[task_options]["model"]
    elif "model" in defaults:
        default_models = targeted_task["default"]["model"]
    else:
        # XXX This error message needs to be updated to be more generic if more tasks are going to become
        # parametrized
        raise ValueError(
            'The task defaults can\'t be correctly selected. You probably meant "translation_XX_to_YY"'
        )

    if framework is None:
        framework = "pt"

    return default_models[framework]


class PipelineException(Exception):
    """
    Raised by a :class:`~transformers.Pipeline` when handling __call__.

    Args:
        task (:obj:`str`): The task of the pipeline.
        model (:obj:`str`): The model used by the pipeline.
        reason (:obj:`str`): The error message to display.
    """

    def __init__(self, task: str, model: str, reason: str):
        super().__init__(reason)

        self.task = task
        self.model = model


class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class DefaultArgumentHandler(ArgumentHandler):
    """
    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.
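
    Example (a minimal sketch of the normalization behavior)::

        handler = DefaultArgumentHandler()
        handler("Hello", "World")  # -> ["Hello", "World"]
        handler(X=["a", "b"])      # -> ["a", "b"]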
    """

    @staticmethod
    def handle_kwargs(kwargs: Dict) -> List:
        if len(kwargs) == 1:
            output = list(kwargs.values())
        else:
            output = list(chain(kwargs.values()))

        return DefaultArgumentHandler.handle_args(output)

    @staticmethod
    def handle_args(args: Sequence[Any]) -> List[str]:

        # Only one argument, let's do case by case
        if len(args) == 1:
            if isinstance(args[0], str):
                return [args[0]]
            elif not isinstance(args[0], list):
                return list(args)
            else:
                return args[0]

        # Multiple arguments (x1, x2, ...)
        elif len(args) > 1:
            if all([isinstance(arg, str) for arg in args]):
                return list(args)

            # If not an instance of list, then it should be an instance of Iterable
            elif isinstance(args, Iterable):
                return list(chain.from_iterable(chain(args)))
            else:
                raise ValueError(
                    "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
                )
        else:
            return []

    def __call__(self, *args, **kwargs):
        if len(kwargs) > 0 and len(args) > 0:
            raise ValueError("Pipeline cannot handle mixed args and kwargs")

        if len(kwargs) > 0:
            return DefaultArgumentHandler.handle_kwargs(kwargs)
        else:
            return DefaultArgumentHandler.handle_args(args)


class PipelineDataFormat:
    """
    Base class for all the pipeline-supported data formats, for both reading and writing. The data formats
    currently supported are:

    - JSON
    - CSV
    - stdin/stdout (pipe)

    :obj:`PipelineDataFormat` also includes some utilities to work with multi-column data, like mapping dataset
    columns to pipeline keyword arguments through the :obj:`dataset_kwarg_1=dataset_column_1` format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
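
    Example (paths and column names are illustrative)::

        # Map the dataset columns "question" and "context" to the pipeline keyword
        # arguments of the same names through the kwarg=column syntax.
        data_format = PipelineDataFormat.from_str(
            "csv", output_path=None, input_path="data.csv", column="question=question,context=context"
        )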
    """

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite: bool = False,
    ):
        self.output_path = output_path
        self.input_path = input_path
        self.column = column.split(",") if column is not None else [""]
        self.is_multi_columns = len(self.column) > 1

        if self.is_multi_columns:
            self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]

        if output_path is not None and not overwrite:
            if exists(abspath(self.output_path)):
                raise OSError("{} already exists on disk".format(self.output_path))

        if input_path is not None:
            if not exists(abspath(self.input_path)):
                raise OSError("{} doesn't exist on disk".format(self.input_path))

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Union[dict, List[dict]]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Save the provided data object as pickle-formatted binary data on the disk.
        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.

        Returns:
            :obj:`str`: Path where the data has been saved.
        """
        path, _ = os.path.splitext(self.output_path)
        binary_path = os.path.extsep.join((path, "pickle"))

        with open(binary_path, "wb+") as f_output:
            pickle.dump(data, f_output)

        return binary_path

    @staticmethod
    def from_str(
        format: str,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ) -> "PipelineDataFormat":
        """
        Creates an instance of the right subclass of :class:`~transformers.pipelines.PipelineDataFormat` depending
        on :obj:`format`.

        Args:
            format: (:obj:`str`):
                The format of the desired pipeline. Acceptable values are :obj:`"json"`, :obj:`"csv"` or :obj:`"pipe"`.
            output_path (:obj:`str`, `optional`):
                Where to save the outgoing data.
            input_path (:obj:`str`, `optional`):
                Where to look for the input data.
            column (:obj:`str`, `optional`):
                The column to read.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to overwrite the :obj:`output_path`.

        Returns:
            :class:`~transformers.pipelines.PipelineDataFormat`: The proper data format.
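
        Example (paths are illustrative)::

            # Read the "text" column of a CSV file and write predictions alongside it
            data_format = PipelineDataFormat.from_str(
                "csv", output_path="predictions.csv", input_path="data.csv", column="text"
            )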
        """
        if format == "json":
            return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        elif format == "csv":
            return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        elif format == "pipe":
            return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        else:
            raise KeyError("Unknown reader {} (Available readers are json/csv/pipe)".format(format))


class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using CSV data format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        with open(self.input_path, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`List[dict]`): The data to store.
        """
        with open(self.output_path, "w") as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using JSON file format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        with open(input_path, "r") as f:
            self._entries = json.load(f)

    def __iter__(self):
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        """
        Save the provided data object in a json file.

        Args:
            data (:obj:`dict`): The data to store.
        """
        with open(self.output_path, "w") as f:
            json.dump(data, f)


class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the Python process.
    For multi-column data, columns should be separated by \t

    If columns are provided, then the output will be a dictionary with {column_x: value_x}

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __iter__(self):
        for line in sys.stdin:
            # Split for multi-columns
            if "\t" in line:
                line = line.split("\t")
                if self.column:
                    # Dictionary to map arguments
                    yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
                else:
                    yield tuple(line)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        """
        Print the data.

        Args:
            data (:obj:`dict`): The data to print.
        """
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        if self.output_path is None:
            raise KeyError(
                "When using piped input, a pipeline outputting large objects requires an output file path. "
                "Please provide such an output path through the --output argument."
            )

        return super().save_binary(data)


class _ScikitCompat(ABC):
    """
    Interface layer for scikit-learn and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()


PIPELINE_INIT_ARGS = r"""
    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, a positive integer will run
            the model on the associated CUDA device id.
        binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Flag indicating if the output of the pipeline should happen in a binary format (i.e., pickle) or as
            raw text.
"""


@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations.
    Pipeline workflow is defined as a sequence of the following operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipelines, like :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'`), output
    large tensor objects as nested lists. In order to avoid dumping such large structures as textual data, we
    provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
    pickle format.
    """

    default_input_names = None

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
    ):
        if framework is None:
            framework = get_framework(model)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.modelcard = modelcard
        self.framework = framework
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
        self.binary_output = binary_output
        self._args_parser = args_parser or DefaultArgumentHandler()

        # Special handling
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer.

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to save the model and tokenizer. It will be created if it doesn't
                exist.
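
        Example (a sketch; ``pipe`` stands for any constructed pipeline and the path is illustrative)::

            pipe.save_pretrained("./my-pipeline")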
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in a framework-agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
634
        """
        if self.framework == "tf":
            with tf.device("/CPU:0" if self.device == -1 else "/device:GPU:{}".format(self.device)):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
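
        Example (an internal sketch; the tensor variables are assumed to exist)::

            inputs = self.ensure_tensor_on_device(input_ids=input_ids, attention_mask=attention_mask)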
        """
        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check if the model class is supported by the pipeline.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.
        """
        if not isinstance(supported_models, list):  # Create from a model mapping
            supported_models = [item[1].__name__ for item in supported_models.items()]
        if self.model.__class__.__name__ not in supported_models:
            raise PipelineException(
                self.task,
                self.model.base_model_prefix,
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
            )

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
        )

        return inputs

    def __call__(self, *args, **kwargs):
        inputs = self._parse_and_tokenize(*args, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework-specific forward dispatching.
        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array.
        Returns:
            Numpy array
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs.data, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        else:
            return predictions.numpy()


# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    This feature extraction pipeline can currently be loaded from :func:`~transformers.pipeline` using the task
    identifier: :obj:`"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models on
    `huggingface.co/models <https://huggingface.co/models>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, a positive integer will run
            the model on the associated CUDA device id.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        task: str = "",
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of :obj:`float`: The features computed by the model.
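
        Example (a minimal sketch; the resolved checkpoint depends on the task defaults)::

            extractor = pipeline("feature-extraction")
            features = extractor("We are very happy to show you the Transformers library.")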
        """
        return super().__call__(*args, **kwargs).tolist()


@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any :obj:`ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    This language generation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2).
    See the list of available community models on
    `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
    (except for Alexei and Maria) are discovered.
    The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
    remainder of the story. 1883 Western Siberia,
    a young Grigori Rasputin is asked by his father and a group of men to perform magic.
    Rasputin has a vision and denounces one of the men as a horse thief. Although his
    father initially slaps him for making such an accusation, Rasputin watches as the
    man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
    with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

    ALLOWED_MODELS = [
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "ReformerModelWithLMHead",
        "GPT2LMHeadModel",
        "OpenAIGPTLMHeadModel",
        "CTRLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
        "TFGPT2LMHeadModel",
        "TFOpenAIGPTLMHeadModel",
        "TFCTRLLMHeadModel",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(self.ALLOWED_MODELS)

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            **tokenizer_kwargs,
        )

        return inputs

    def __call__(
        self,
        *args,
        return_tensors=False,
        return_text=True,
        clean_up_tokenization_spaces=False,
        prefix=None,
        **generate_kwargs
    ):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (:obj:`str`, `optional`):
                Prefix added to prompt.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
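
        Example (a minimal sketch; the resolved checkpoint depends on the task defaults)::

            generator = pipeline("text-generation")
            generator("Once upon a time,", max_length=30)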
        """
        text_inputs = self._args_parser(*args)

        results = []
        for prompt_text in text_inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                prefix = prefix if prefix is not None else self.model.config.prefix
                if prefix is None and self.model.__class__.__name__ in [
                    "XLNetLMHeadModel",
                    "TransfoXLLMHeadModel",
                    "TFXLNetLMHeadModel",
                    "TFTransfoXLLMHeadModel",
                ]:
                    # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                    prefix = self.XL_PREFIX

                if prefix:
                    prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False)
                    # This impacts the max_length and min_length arguments, which need adjusting.
                    prefix_length = prefix_inputs["input_ids"].shape[-1]
                    if generate_kwargs.get("max_length", None) is not None:
                        generate_kwargs["max_length"] += prefix_length
                    if generate_kwargs.get("min_length", None) is not None:
                        generate_kwargs["min_length"] += prefix_length

                prefix = prefix or ""
                inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False)

                # set input_ids to None to allow empty prompt
                if inputs["input_ids"].shape[-1] == 0:
                    inputs["input_ids"] = None
                    inputs["attention_mask"] = None

                if self.framework == "pt" and inputs["input_ids"] is not None:
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids is None or input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."

                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL

            result = []
            for generated_sequence in output_sequences:
                if self.framework == "pt" and generated_sequence is not None:
                    generated_sequence = generated_sequence.cpu()
                generated_sequence = generated_sequence.numpy().tolist()
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )

                    # Remove the padding prompt from the sequence if an XLNet or Transfo-XL model is used
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    record["generated_text"] = prompt_text + text[prompt_length:]

                result.append(record)
            results += [result]

        if len(results) == 1:
            return results[0]

        return results


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        return_all_scores (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to return all prediction scores or just the one of the predicted class.
    """,
)
class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using any :obj:`ModelForSequenceClassification`. See the
    `sequence classification examples <../task_summary.html#sequence-classification>`__ for more information.

    This text classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"sentiment-analysis"` (for classifying sequences according to positive or negative
    sentiments).

    If multiple classification labels are available (:obj:`model.config.num_labels >= 2`), the pipeline will run
    a softmax over the results. If there is a single label, the pipeline will run a sigmoid over the result.

    The models that this pipeline can use are models that have been fine-tuned on a sequence classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=text-classification>`__.
    """

    def __init__(self, return_all_scores: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
        )

        self.return_all_scores = return_all_scores

    def __call__(self, *args, **kwargs):
        """
        Classify the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) to classify.
        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
            following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.

            If ``self.return_all_scores=True``, one such dictionary is returned per label.
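
        Example (a minimal sketch; the label and score shown are illustrative)::

            classifier = pipeline("sentiment-analysis")
            classifier("I love this movie!")  # e.g. [{'label': 'POSITIVE', 'score': 0.99}]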
        """
        outputs = super().__call__(*args, **kwargs)

        if self.model.config.num_labels == 1:
            scores = 1.0 / (1.0 + np.exp(-outputs))
        else:
            scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(item)]
                for item in scores
            ]
        else:
            return [
                {"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
            ]


class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot text classification by turning each possible label into an NLI
    premise/hypothesis pair.
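
    Example (a sketch of the premise/hypothesis pairing)::

        handler = ZeroShotClassificationArgumentHandler()
        handler("I love this movie", ["positive", "negative"], "This example is {}.")
        # -> [["I love this movie", "This example is positive."],
        #     ["I love this movie", "This example is negative."]]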
    """

    def _parse_labels(self, labels):
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]
        labels = self._parse_labels(labels)

        sequence_pairs = []
        for sequence in sequences:
            sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])

        return sequence_pairs


@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a :obj:`ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the
    candidate label being valid. Any NLI model can be used as long as the first output logit corresponds to
    `contradiction` and the last to `entailment`.

    This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?search=nli>`__.
    """

    def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
        super().__init__(*args, args_parser=args_parser, **kwargs)

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize, using ``only_first`` truncation so that the hypothesis (label) is not truncated
        """
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation="only_first",
        )

        return inputs

    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
        """
        Classify the sequence(s) given as inputs.

        Args:
            sequences (:obj:`str` or :obj:`List[str]`):
                The sequence(s) to classify. Will be truncated if the model input is too large.
            candidate_labels (:obj:`str` or :obj:`List[str]`):
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`):
                The template used to turn each label into an NLI-style hypothesis. This template must include a {}
                or similar syntax for the candidate label to be inserted into the template. For example, the default
                template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed
                into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The
                default template works well in many cases, but it may be worthwhile to experiment with different
                templates depending on the task setting.
            multi_class (:obj:`bool`, `optional`, defaults to :obj:`False`):
Sylvain Gugger's avatar
Sylvain Gugger committed
1136
1137
1138
                Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized
                such that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are
                considered independent and probabilities are normalized for each candidate by doing a softmax of
1139
                the entailment score vs. the contradiction score.
1140

Sylvain Gugger's avatar
Sylvain Gugger committed
1141
1142
1143
1144
1145
1146
        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **sequence** (:obj:`str`) -- The sequence for which this is the output.
            - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
1147
            - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
        """
        outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
        num_sequences = 1 if isinstance(sequences, str) else len(sequences)
        candidate_labels = self._args_parser._parse_labels(candidate_labels)
        reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

        if len(candidate_labels) == 1:
            multi_class = True

        if not multi_class:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., -1]
            scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
        else:
            # softmax over the entailment vs. contradiction dim for each label independently
            entail_contr_logits = reshaped_outputs[..., [0, -1]]
            scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
            scores = scores[..., 1]

        result = []
        for iseq in range(num_sequences):
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences if isinstance(sequences, str) else sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq][top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        top_k (:obj:`int`, defaults to 5): The number of predictions to return.
    """,
)
class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using any :obj:`ModelWithLMHead`. See the
    `masked language modeling examples <../task_summary.html#masked-language-modeling>`__ for more information.

    This mask filling pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"fill-mask"`.

    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
    which includes the bi-directional models in the library.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=masked-lm>`__.

    .. note::

        This pipeline only works for inputs with exactly one token masked.
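
    Usage (an illustrative sketch; the mask token string and the returned scores depend on the model that
    :func:`~transformers.pipeline` loads, hence ``tokenizer.mask_token`` is used instead of a hard-coded mask)::

        unmasker = pipeline("fill-mask")
        unmasker(f"Hello, I'm a {unmasker.tokenizer.mask_token} model.", top_k=2)

        # Restrict scoring to specific candidate tokens
        unmasker(f"Hello, I'm a {unmasker.tokenizer.mask_token} model.", targets=["role"])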
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        top_k=5,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)

        if "topk" in kwargs:
            warnings.warn(
                "The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
                FutureWarning,
            )
            self.top_k = kwargs.pop("topk")
        else:
            self.top_k = top_k

    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
        numel = np.prod(masked_index.shape)
        if numel > 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
            )
        elif numel < 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
            )

    def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
        """
        Fill the masked token in the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of prompts) with masked tokens.
            targets (:obj:`str` or :obj:`List[str]`, `optional`):
                When passed, the model will return the scores for the passed token or tokens rather than the top k
                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
                be tokenized and the first resulting token will be used (with a warning).
            top_k (:obj:`int`, `optional`):
                When passed, overrides the number of predictions to return.

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries with the
            following keys:

            - **sequence** (:obj:`str`) -- The corresponding input with the mask token prediction.
            - **score** (:obj:`float`) -- The corresponding probability.
            - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
            - **token_str** (:obj:`str`) -- The predicted token (to replace the masked one).
        """
        inputs = self._parse_and_tokenize(*args, **kwargs)
        outputs = self._forward(inputs, return_tensors=True)

        results = []
        batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

        if targets is not None:
            if len(targets) == 0 or len(targets[0]) == 0:
                raise ValueError("At least one target must be provided when passed.")
            if isinstance(targets, str):
                targets = [targets]

            targets_proc = []
            for target in targets:
                target_enc = self.tokenizer.tokenize(target)
                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
                    logger.warning(
                        "The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format(
                            target, target_enc[0]
                        )
                    )
                targets_proc.append(target_enc[0])
            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))

        for i in range(batch_size):
            input_ids = inputs["input_ids"][i]
            result = []

            if self.framework == "tf":
                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index)

                logits = outputs[i, masked_index.item(), :]
                probs = tf.nn.softmax(logits)
                if targets is None:
                    topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
                    values, predictions = topk.values.numpy(), topk.indices.numpy()
                else:
                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
                    sort_inds = tf.reverse(tf.argsort(values), [0])
                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
                    predictions = target_inds[sort_inds.numpy()]
            else:
                masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index.numpy())

                logits = outputs[i, masked_index.item(), :]
                probs = logits.softmax(dim=0)
                if targets is None:
                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
                else:
                    values = probs[..., target_inds]
                    sort_inds = list(reversed(values.argsort(dim=-1)))
                    values = values[..., sort_inds]
                    predictions = target_inds[sort_inds]

            for v, p in zip(values.tolist(), predictions.tolist()):
                tokens = input_ids.numpy()
                tokens[masked_index] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                result.append(
                    {
                        "sequence": self.tokenizer.decode(tokens),
                        "score": v,
                        "token": p,
                        "token_str": self.tokenizer.convert_ids_to_tokens(p),
                    }
                )

            # Append
            results += [result]

        if len(results) == 1:
            return results[0]
        return results


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
            A list of labels to ignore.
        grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
    """,
)
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the
    `named entity recognition examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=token-classification>`__.
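
    Usage (an illustrative sketch; the label set, such as ``PER`` or ``LOC``, depends on the fine-tuned
    model that :func:`~transformers.pipeline` loads)::

        ner = pipeline("ner", grouped_entities=True)
        ner("My name is Wolfgang and I live in Berlin.")
        # e.g. [{'entity_group': 'PER', 'word': 'Wolfgang', ...}, {'entity_group': 'LOC', 'word': 'Berlin', ...}]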
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        ignore_labels=["O"],
        task: str = "",
        grouped_entities: bool = False,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self.ignore_labels = ignore_labels
        self.grouped_entities = grouped_entities

    def __call__(self, *args, **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with
            :obj:`grouped_entities=True`) with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word.
            - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
              corresponding token in the sentence.
        """
        inputs = self._args_parser(*args, **kwargs)
        answers = []
        for sentence in inputs:

            # Manage correct placement of the tensors
            with self.device_placement():

                tokens = self.tokenizer(
                    sentence,
                    return_attention_mask=False,
                    return_tensors=self.framework,
                    truncation=True,
                )

                # Forward
                if self.framework == "tf":
                    entities = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        entities = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)

            entities = []
            # Filter to labels not in `self.ignore_labels`
            filtered_labels_idx = [
                (idx, label_idx)
                for idx, label_idx in enumerate(labels_idx)
                if self.model.config.id2label[label_idx] not in self.ignore_labels
            ]

            for idx, label_idx in filtered_labels_idx:

                entity = {
                    "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
                    "score": score[idx][label_idx].item(),
                    "entity": self.model.config.id2label[label_idx],
                    "index": idx,
                }

                entities += [entity]

            # Append grouped entities
            if self.grouped_entities:
                answers += [self.group_entities(entities)]
            # Append ungrouped entities
            else:
                answers += [entities]

        if len(answers) == 1:
            return answers[0]
        return answers

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.
        """
        # Get the first entity in the entity group
        entity = entities[0]["entity"]
        scores = [entity["score"] for entity in entities]
        tokens = [entity["word"] for entity in entities]

        entity_group = {
            "entity_group": entity,
            "score": np.mean(scores),
            "word": self.tokenizer.convert_tokens_to_string(tokens),
        }
        return entity_group

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.
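
        Example (illustrative): consecutive tokens tagged ``B-PER`` and ``I-PER`` at adjacent indices are merged
        into a single ``PER`` entity group, while a following ``B-LOC`` token starts a new ``LOC`` group.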
        """

        entity_groups = []
        entity_group_disagg = []

        if entities:
            last_idx = entities[-1]["index"]

        for entity in entities:
            is_last_idx = entity["index"] == last_idx
            if not entity_group_disagg:
                entity_group_disagg += [entity]
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
                continue

            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
            # The split is meant to account for the "B" and "I" prefixes
            if (
                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
                and entity["index"] == entity_group_disagg[-1]["index"] + 1
            ):
                entity_group_disagg += [entity]
                # Group the entities at the last entity
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
            else:
                entity_groups += [self.group_sub_entities(entity_group_disagg)]
                entity_group_disagg = [entity]
                # If it's the last entity, add it to the entity groups
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]

        return entity_groups


NerPipeline = TokenClassificationPipeline


class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal :class:`~transformers.SquadExample`.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a
    :class:`~transformers.SquadExample` from the command-line supplied arguments.
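
    Usage (an illustrative sketch of the normalization this handler performs)::

        handler = QuestionAnsweringArgumentHandler()
        handler(question="Where do I live?", context="My name is Wolfgang and I live in Berlin.")
        # => [SquadExample(...)]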
    """

    def __call__(self, *args, **kwargs):
        # Positional args, handling is essentially the same as X and data, so forward to them to avoid duplicating
        if args is not None and len(args) > 0:
            if len(args) == 1:
                kwargs["X"] = args[0]
            else:
                kwargs["X"] = list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if "X" in kwargs or "data" in kwargs:
            inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]

            if isinstance(inputs, dict):
                inputs = [inputs]
            else:
                # Copy to avoid overriding arguments
                inputs = [i for i in inputs]

            for i, item in enumerate(inputs):
                if isinstance(item, dict):
                    if any(k not in item for k in ["question", "context"]):
                        raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

                    inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif not isinstance(item, SquadExample):
                    raise ValueError(
                        "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
                            "X" if "X" in kwargs else "data"
                        )
                    )

        # Tabular input
        elif "question" in kwargs and "context" in kwargs:
            if isinstance(kwargs["question"], str):
                kwargs["question"] = [kwargs["question"]]

            if isinstance(kwargs["context"], str):
                kwargs["context"] = [kwargs["context"]]

            inputs = [
                QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
            ]
        else:
            raise ValueError("Unknown arguments {}".format(kwargs))

        if not isinstance(inputs, list):
            inputs = [inputs]

        return inputs


@add_end_docstrings(PIPELINE_INIT_ARGS)
class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using any :obj:`ModelForQuestionAnswering`. See the
    `question answering examples <../task_summary.html#question-answering>`__ for more information.

    This question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=question-answering>`__.
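
    Usage (an illustrative sketch; the exact span and score depend on the model that
    :func:`~transformers.pipeline` loads)::

        qa = pipeline("question-answering")
        qa(question="Where do I live?", context="My name is Wolfgang and I live in Berlin.")
        # e.g. {'score': 0.9, 'start': 34, 'end': 40, 'answer': 'Berlin'}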
    """

    default_input_names = "question,context"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=QuestionAnsweringArgumentHandler(),
            device=device,
            task=task,
            **kwargs,
        )

        self.check_model_type(
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
        )

    @staticmethod
    def create_sample(
        question: Union[str, List[str]], context: Union[str, List[str]]
    ) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the :class:`~transformers.SquadExample` internally.
        This helper method encapsulates all the logic for converting question(s) and context(s) to
        :class:`~transformers.SquadExample`.

        We currently support extractive question answering.

        Arguments:
            question (:obj:`str` or :obj:`List[str]`): The question(s) asked.
            context (:obj:`str` or :obj:`List[str]`): The context(s) in which we will look for the answer.

        Returns:
            One or a list of :class:`~transformers.SquadExample`: The corresponding
            :class:`~transformers.SquadExample` grouping question and context.
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            args (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`):
                One or several :class:`~transformers.SquadExample` containing the question and context.
            X (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            data (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            question (:obj:`str` or :obj:`List[str]`):
                One or several question(s) (must be used in conjunction with the :obj:`context` argument).
            context (:obj:`str` or :obj:`List[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                :obj:`question` argument).
            topk (:obj:`int`, `optional`, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood).
            doc_stride (:obj:`int`, `optional`, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split into several
                chunks with some overlap. This argument controls the size of that overlap.
            max_answer_len (:obj:`int`, `optional`, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (:obj:`int`, `optional`, defaults to 384):
                The maximum length of the total sentence (context + question) after tokenization. The context will be
                split into several chunks (using :obj:`doc_stride`) if needed.
            max_question_len (:obj:`int`, `optional`, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not we accept impossible as an answer.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **score** (:obj:`float`) -- The probability associated to the answer.
            - **start** (:obj:`int`) -- The start index of the answer (in the tokenized version of the input).
            - **end** (:obj:`int`) -- The end index of the answer (in the tokenized version of the input).
            - **answer** (:obj:`str`) -- The answer to the question.
        """
        # Set default values
        kwargs.setdefault("topk", 1)
        kwargs.setdefault("doc_stride", 128)
        kwargs.setdefault("max_answer_len", 15)
        kwargs.setdefault("max_seq_len", 384)
        kwargs.setdefault("max_question_len", 64)
        kwargs.setdefault("handle_impossible_answer", False)

        if kwargs["topk"] < 1:
            raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))

        if kwargs["max_answer_len"] < 1:
            raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

        # Convert inputs to features
        examples = self._args_parser(*args, **kwargs)
        features_list = [
            squad_convert_examples_to_features(
                examples=[example],
                tokenizer=self.tokenizer,
                max_seq_length=kwargs["max_seq_len"],
                doc_stride=kwargs["doc_stride"],
                max_query_length=kwargs["max_question_len"],
                padding_strategy=PaddingStrategy.MAX_LENGTH.value,
                is_training=False,
                tqdm_enabled=False,
            )
            for example in examples
        ]
        all_answers = []
        for features, example in zip(features_list, examples):
            model_input_names = self.tokenizer.model_input_names + ["input_ids"]
            fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

            # Manage tensor allocation on correct device
            with self.device_placement():
                if self.framework == "tf":
                    fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                    start, end = self.model(fw_args)[:2]
                    start, end = start.numpy(), end.numpy()
                else:
                    with torch.no_grad():
                        # Retrieve the score for the context tokens only (removing question tokens)
                        fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
                        start, end = self.model(**fw_args)[:2]
                        start, end = start.cpu().numpy(), end.cpu().numpy()

            min_null_score = 1000000  # large and positive
            answers = []
            for (feature, start_, end_) in zip(features, start, end):
                # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
                undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

                # Generate mask
                undesired_tokens_mask = undesired_tokens == 0.0

                # Make sure non-context indexes in the tensor cannot contribute to the softmax
                start_ = np.where(undesired_tokens_mask, -10000.0, start_)
                end_ = np.where(undesired_tokens_mask, -10000.0, end_)

                # Normalize logits and spans to retrieve the answer
                start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
                end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))

                if kwargs["handle_impossible_answer"]:
                    min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

                # Mask CLS
                start_[0] = end_[0] = 0.0

                starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
                char_to_word = np.array(example.char_to_word_offset)

                # Convert the answer (tokens) back to the original text
                answers += [
                    {
                        "score": score.item(),
                        "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                        "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
                        "answer": " ".join(
                            example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
                        ),
                    }
                    for s, e, score in zip(starts, ends, scores)
                ]

            if kwargs["handle_impossible_answer"]:
                answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})

            answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
            all_answers += answers

        if len(all_answers) == 1:
            return all_answers[0]
        return all_answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Take the output of any :obj:`ModelForQuestionAnswering` and generate probabilities for each span to be
        the actual answer.

        In addition, it filters out some unwanted/impossible cases, like the answer length being greater than
        max_answer_len or the answer end position coming before the start position.
        The method supports outputting the k best answers through the topk argument.

        Args:
            start (:obj:`np.ndarray`): Individual start probabilities for each token.
            end (:obj:`np.ndarray`): Individual end probabilities for each token.
            topk (:obj:`int`): Indicates how many possible answer span(s) to extract from the model output.
            max_answer_len (:obj:`int`): Maximum size of the answer to extract from the model's output.
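
        Example (illustrative numbers only): with ``start = [0.1, 0.6, 0.3]`` and ``end = [0.2, 0.2, 0.6]``,
        the outer product scores the span ``(1, 2)`` as ``0.6 * 0.6 = 0.36``; spans whose end precedes their
        start, or whose length exceeds ``max_answer_len``, are zeroed out before the top-k selection.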
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidates with end < start or end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) < topk:
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]

    def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to the actual words in
        the initial context.

        Args:
            text (:obj:`str`): The actual context to extract the answer from.
            start (:obj:`int`): The answer starting token index.
            end (:obj:`int`): The answer end token index.

        Returns:
            Dictionary like :obj:`{'answer': str, 'start': int, 'end': int}`
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for i, word in enumerate(text.split(" ")):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {
            "answer": " ".join(words),
            "start": max(0, char_start_idx),
            "end": min(len(text), char_end_idx),
        }


@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Pipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"summarization"`.

    The models that this pipeline can use are models that have been fine-tuned on a summarization task, which
    currently includes '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`' and '`t5-11b`'.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.

    Usage::

        # use bart in pytorch
        summarizer = pipeline("summarization")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

        # use t5 in tf
        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
    """

    def __init__(self, *args, **kwargs):
        kwargs.update(task="summarization")
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                One or several articles (or one list of articles) to summarize.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
              input.
            - **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the summary.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        assert len(documents) > 0, "Please provide a document to summarize"

        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(documents[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"

            documents = ([prefix + document for document in documents[0]],)
            padding = True

        elif isinstance(documents[0], str):
            documents = (prefix + documents[0],)
            padding = False
        else:
            raise ValueError(
                "`documents[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
                    documents[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*documents, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]
            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            min_length = generate_kwargs.get("min_length", self.model.config.min_length)
            if input_length < min_length // 2:
                logger.warning(
                    "Your min_length is set to {}, but your input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
                        min_length, input_length
                    )
                )

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length < max_length:
                logger.warning(
                    "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
                        max_length, input_length
                    )
                )

            summaries = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )

            results = []
            for summary in summaries:
                record = {}
                if return_tensors:
                    record["summary_token_ids"] = summary
                if return_text:
                    record["summary_text"] = self.tokenizer.decode(
                        summary,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


@add_end_docstrings(PIPELINE_INIT_ARGS)
class TranslationPipeline(Pipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=translation>`__.

    Usage::

        en_fr_translator = pipeline("translation_en_to_fr")
        en_fr_translator("How old are you?")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Texts to be translated.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(args[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            args = ([prefix + text for text in args[0]],)
            padding = True

        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                "`args[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
                    args[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length > 0.9 * max_length:
                logger.warning(
                    "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
                        input_length, max_length
                    )
                )

            translations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )
            results = []
            for translation in translations:
                record = {}
                if return_tensors:
                    record["translation_token_ids"] = translation
                if return_text:
                    record["translation_text"] = self.tokenizer.decode(
                        translation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


@add_end_docstrings(PIPELINE_INIT_ARGS)
class Text2TextGenerationPipeline(Pipeline):
    """
    Pipeline for text to text generation using seq2seq models.

    This Text2TextGenerationPipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"text2text-generation"`.

    The models that this pipeline can use are models that have been fine-tuned on a sequence-to-sequence task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=seq2seq>`__.

    Usage::

        text2text_generator = pipeline("text2text-generation")
        text2text_generator("question: What is 42 ? context: 42 is the answer to life, the universe and everything")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Generate the output text(s) using text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Input text for the encoder.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
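
        Example (a minimal sketch; the exact generated text depends on the model and generation parameters)::

            text2text_generator = pipeline("text2text-generation")
            text2text_generator("translate English to French: I love pizza.")
            # [{'generated_text': "J'adore la pizza."}]  (illustrative output)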
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

        if isinstance(args[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            padding = True

        elif isinstance(args[0], str):
            padding = False
        else:
            raise ValueError(
                " `documents[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
                    args[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)

            generations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )
            results = []
            for generation in generations:
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generation
                if return_text:
                    record["generated_text"] = self.tokenizer.decode(
                        generation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    :class:`~transformers.ConversationalPipeline`. The conversation contains a number of utility functions to manage
    the addition of new user input and generated model responses. A conversation needs to contain an unprocessed user
    input before being passed to the :class:`~transformers.ConversationalPipeline`. This user input is either created
    when the class is instantiated, or added afterwards with the :meth:`~transformers.Conversation.add_user_input`
    method before the next conversation turn.

    Arguments:
        text (:obj:`str`, `optional`):
            The initial user input to start the conversation. If not provided, a user input needs to be provided
            manually using the :meth:`~transformers.Conversation.add_user_input` method before the conversation can
            begin.
        conversation_id (:obj:`uuid.UUID`, `optional`):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big Lebowski.")

        conversation.add_user_input("Is it good?")
    """

    def __init__(self, text: str = None, conversation_id: UUID = None):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        self.uuid: UUID = conversation_id
        self.past_user_inputs: List[str] = []
        self.generated_responses: List[str] = []
        self.history: List[int] = []
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This populates the internal :obj:`new_user_input`
        field.

        Args:
            text (:obj:`str`): The user input for the next conversation round.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not existing and unprocessed user input should be overwritten when this function is called.
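
        Example (a minimal sketch of the overwrite behavior)::

            conversation = Conversation("Going to the movies tonight - any suggestions?")
            conversation.add_user_input("Any comedy recommendations?")  # ignored, an unprocessed input is pending
            conversation.add_user_input("Any comedy recommendations?", overwrite=True)  # replaces the pending input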
        """
        if self.new_user_input:
            if overwrite:
                logger.warning(
                    'User input added while unprocessed input was pending: "{}" was overwritten with: "{}".'.format(
                        self.new_user_input, text
                    )
                )
                self.new_user_input = text
            else:
                logger.warning(
                    'User input added while unprocessed input was pending: "{}" new input ignored: "{}". '
                    "Set `overwrite` to True to overwrite the unprocessed user input.".format(self.new_user_input, text)
                )
        else:
            self.new_user_input = text

    def mark_processed(self):
        """
        Mark the conversation as processed (moves the content of :obj:`new_user_input` to :obj:`past_user_inputs`) and
        empties the :obj:`new_user_input` field.
        """
        if self.new_user_input:
            self.past_user_inputs.append(self.new_user_input)
        self.new_user_input = None

    def append_response(self, response: str):
        """
        Append a response to the list of generated responses.

        Args:
            response (:obj:`str`): The model generated response.
        """
        self.generated_responses.append(response)

    def set_history(self, history: List[int]):
        """
        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
        The history is used by the model to generate responses based on the previous conversation turns.

        Args:
            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
        """
        self.history = history

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Return:
            :obj:`str`:

            Example:
            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
            user >> Going to the movies tonight - any suggestions?
            bot >> The Big Lebowski
        """
        output = "Conversation id: {} \n".format(self.uuid)
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            output += "user >> {} \n".format(user_input)
            output += "bot >> {} \n".format(generated_response)
        if self.new_user_input is not None:
            output += "user >> {} \n".format(self.new_user_input)
        return output


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        min_length_for_response (:obj:`int`, `optional`, defaults to 32):
            The minimum length (in number of tokens) for a response.
    """,
)
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    This conversational pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: `'microsoft/DialoGPT-small'`, `'microsoft/DialoGPT-medium'`, `'microsoft/DialoGPT-large'`.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
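        # Fall back to the EOS token for padding when the tokenizer has no pad token (e.g. DialoGPT / GPT-2 tokenizers).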
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.eos_token_id
        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
            updated generated responses for those containing a new user input.
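
            A single :class:`~transformers.Conversation` input returns the updated conversation itself, while a list
            input returns a list. For instance (a sketch; the generated reply depends on the model)::

                conversation = Conversation("Going to the movies tonight - any suggestions?")
                conversation = conversational_pipeline(conversation)
                conversation.generated_responses[-1]  # e.g. "The Big Lebowski" (illustrative)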
        """

        # Input validation
        if isinstance(conversations, list):
            for conversation in conversations:
                assert isinstance(
                    conversation, Conversation
                ), "ConversationalPipeline expects a Conversation or list of Conversations as an input"
                if conversation.new_user_input is None:
                    raise ValueError(
                        "Conversation with UUID {} does not contain new user input to process. "
                        "Add user inputs with the conversation's `add_user_input` method".format(
                            conversation.uuid
                        )
                    )
            assert (
                self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
        elif isinstance(conversations, Conversation):
            conversations = [conversations]
        else:
            raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input")

        with self.device_placement():

            inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
            histories = [conversation.history for conversation in conversations]
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            inputs = self._concat_inputs_history(inputs, histories, max_length)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            if input_length > 0.9 * max_length:
                logger.warning(
                    "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                    "You might consider trimming the early phase of the conversation".format(input_length, max_length)
                )
            generated_responses = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )

            cleaned_history = self._clean_padding_history(generated_responses)
            output = []
            for conversation_index, conversation in enumerate(conversations):
                conversation.mark_processed()
                conversation.generated_responses.append(
                    self.tokenizer.decode(
                        cleaned_history[conversation_index][input_length:],
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                )
                conversation.set_history(cleaned_history[conversation_index])
                output.append(conversation)
            if len(output) == 1:
                return output[0]
            else:
                return output

    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for input_ids in inputs:
            input_ids.append(self.tokenizer.eos_token_id)
        return inputs

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
        an input:
            - at the end of the concatenated history and new user input, so that all input to the model have the same
                length
            - at the end of the generated response, as some responses will be longer than others
        This method cleans up these padding tokens so that the history for each conversation is not impacted by the
        batching process.
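        For example, assuming :obj:`pad_token_id` is 0, a generated sequence ``[5, 8, 0, 0, 0, 3, 0, 0]`` is cleaned
        to ``[5, 8, 0, 3, 0]``: each run of consecutive padding tokens collapses to a single padding token.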
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in sequence:
                if token == self.pad_token_id:
                    if is_previous_pad:
                        continue
                    else:
                        is_previous_pad = True
                else:
                    is_previous_pad = False
                if self.framework == "pt":
                    sequence_tokens.append(token.item())
                else:
                    sequence_tokens.append(int(token.numpy()))

            outputs.append(sequence_tokens)
        return outputs

    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
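        For example (a sketch, assuming :obj:`eos_token_id` is 2), a history ``[10, 11, 2, 12, 2]`` and a new input
        ``[13, 14, 2]`` are concatenated into ``[10, 11, 2, 12, 2, 13, 14, 2]``; when the result exceeds
        ``max_length - min_length_for_response``, the earliest turns are dropped at EOS boundaries.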
        """
        outputs = []
        for new_input, history in zip(inputs, histories):
            if history is not None:
                new_input = history + new_input
            if len(new_input) > max_length - self.min_length_for_response:
                cutoff_eos_index = 0
                while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
                    if cutoff_eos_index >= len(new_input):
                        break
                    cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
                    if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
                        break
                    else:
                        new_input = new_input[cutoff_eos_index + 1 :]
            outputs.append(new_input)
        max_len = max([len(item) for item in outputs])
        attention_masks = [[1] * len(output) + [0] * (max_len - len(output)) for output in outputs]
        outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
        outputs = BatchEncoding(
            {"input_ids": outputs, "attention_mask": attention_masks},
            tensor_type=self.framework,
        )
        return outputs


# Register all the supported tasks here
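# Each entry maps a task name to its pipeline implementation ("impl"), the per-framework
# auto-model classes ("tf"/"pt"), and the default checkpoint(s) loaded when none is provided.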
SUPPORTED_TASKS = {
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": TFAutoModel if is_tf_available() else None,
        "pt": AutoModel if is_torch_available() else None,
        "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
    },
    "sentiment-analysis": {
        "impl": TextClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
            },
        },
    },
    "ner": {
        "impl": TokenClassificationPipeline,
        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
        "pt": AutoModelForTokenClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
            },
        },
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
        "default": {
            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelForMaskedLM if is_tf_available() else None,
        "pt": AutoModelForMaskedLM if is_torch_available() else None,
        "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    "translation": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {
            ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
            ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
            ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
        },
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
        },
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
    },
}


def check_task(task: str) -> Tuple[Dict, Any]:
    """
    Checks an incoming task string to validate that it's correct, and returns the
    default Pipeline and Model classes, and default models if they exist.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`
            - :obj:`"sentiment-analysis"`
            - :obj:`"ner"`
            - :obj:`"question-answering"`
            - :obj:`"fill-mask"`
            - :obj:`"summarization"`
            - :obj:`"translation_xx_to_yy"`
            - :obj:`"translation"`
            - :obj:`"text2text-generation"`
            - :obj:`"text-generation"`
            - :obj:`"zero-shot-classification"`
            - :obj:`"conversational"`

    Returns:
        (task_defaults:obj:`dict`, task_options: (:obj:`tuple`, None))
            The actual dictionary required to initialize the pipeline and some
            extra task options for parametrized tasks like "translation_XX_to_YY"
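
    Example (a sketch of the parametrized-task case)::

        task_defaults, task_options = check_task("translation_en_to_fr")
        # task_defaults is SUPPORTED_TASKS["translation"] and task_options is ("en", "fr")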


    """
    if task in SUPPORTED_TASKS:
        targeted_task = SUPPORTED_TASKS[task]
        return targeted_task, None

    if task.startswith("translation"):
        tokens = task.split("_")
        if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
            targeted_task = SUPPORTED_TASKS["translation"]
            return targeted_task, (tokens[1], tokens[3])
        raise KeyError("Invalid translation task {}, use 'translation_XX_to_YY' format".format(task))

    raise KeyError(
        "Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()) + ["translation_XX_to_YY"])
    )


def pipeline(
    task: str,
    model: Optional[Any] = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    framework: Optional[str] = None,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a :class:`~transformers.Pipeline`.

    Pipelines are made of:

        - A :doc:`tokenizer <tokenizer>` in charge of mapping raw textual input to tokens.
        - A :doc:`model <model>` to make predictions from the inputs.
        - Some (optional) post processing for enhancing model's output.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
            - :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
            - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
            - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
            - :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
            - :obj:`"text2text-generation"`: will return a :class:`~transformers.Text2TextGenerationPipeline`.
            - :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
            - :obj:`"zero-shot-classification"`: will return a :class:`~transformers.ZeroShotClassificationPipeline`.
            - :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)
            or :class:`~transformers.TFPreTrainedModel` (for TensorFlow).

            If not provided, the default for the :obj:`task` will be loaded.
        config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from
            :class:`~transformers.PretrainedConfig`.

            If not provided, the default for the :obj:`task` will be loaded.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained tokenizer inheriting from
            :class:`~transformers.PreTrainedTokenizer`.

            If not provided, the default for the :obj:`task` will be loaded.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        kwargs:
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        :class:`~transformers.Pipeline`: A suitable pipeline for the task.

    Examples::

        >>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

        >>> # Sentiment analysis pipeline
        >>> pipeline('sentiment-analysis')

        >>> # Question answering pipeline, specifying the checkpoint identifier
        >>> pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')

        >>> # Named entity recognition pipeline, passing in a specific model and tokenizer
        >>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        >>> pipeline('ner', model=model, tokenizer=tokenizer)
    """
    # Retrieve the task
    targeted_task, task_options = check_task(task)

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        # At that point framework might still be undetermined
        model = get_default_model(targeted_task, framework, task_options)

    framework = framework or get_framework(model)

    task_class, model_class = targeted_task["impl"], targeted_task[framework]

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        elif isinstance(config, str):
            tokenizer = config
        else:
            # Impossible to guess which tokenizer to use here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    modelcard = None
    # Try to infer modelcard from model or config name (if provided as str)
    if isinstance(model, str):
        modelcard = model
    elif isinstance(config, str):
        modelcard = config

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, (str, tuple)):
        if isinstance(tokenizer, tuple):
            # For tuple we have (tokenizer name, {kwargs})
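            # For example: tokenizer=("bert-base-cased", {"use_fast": False}) forwards use_fast to from_pretrained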
            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
        else:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
        modelcard = ModelCard.from_pretrained(modelcard)

    # Instantiate model if needed
    if isinstance(model, str):
        # Handle transparent TF/PT model conversion
        model_kwargs = {}
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )
        model = model_class.from_pretrained(model, config=config, **model_kwargs)
        if task == "translation" and model.config.task_specific_params:
            for key in model.config.task_specific_params:
                if key.startswith("translation"):
                    task = key
                    warnings.warn(
                        '"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
                            task
                        ),
                        UserWarning,
                    )
                    break

    return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)