"pytorch_transformers/preprocessing/glue.py" did not exist on "bac332fec06c843c8c3436091bb53343c109202d"
pipelines.py 108 KB
Newer Older
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

Morgan Funtowicz's avatar
Morgan Funtowicz committed
16

17
18
import csv
import json
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import logging
Morgan Funtowicz's avatar
Morgan Funtowicz committed
20
import os
21
import pickle
Aymeric Augustin's avatar
Aymeric Augustin committed
22
import sys
23
import uuid
Morgan Funtowicz's avatar
Morgan Funtowicz committed
24
from abc import ABC, abstractmethod
25
from contextlib import contextmanager
26
from itertools import chain
27
from os.path import abspath, exists
28
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
29
from uuid import UUID
Morgan Funtowicz's avatar
Morgan Funtowicz committed
30
31
32

import numpy as np

33
from .configuration_auto import AutoConfig
34
35
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
Sylvain Gugger's avatar
Sylvain Gugger committed
36
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
37
38
39
40
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
41
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
Morgan Funtowicz's avatar
Morgan Funtowicz committed
42

Aymeric Augustin's avatar
Aymeric Augustin committed
43

Morgan Funtowicz's avatar
Morgan Funtowicz committed
44
if is_tf_available():
Morgan Funtowicz's avatar
Morgan Funtowicz committed
45
    import tensorflow as tf
46
    from .modeling_tf_auto import (
47
48
49
50
        TFAutoModel,
        TFAutoModelForSequenceClassification,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForTokenClassification,
Julien Chaumond's avatar
Julien Chaumond committed
51
        TFAutoModelWithLMHead,
52
53
54
55
        TF_MODEL_WITH_LM_HEAD_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
56
        TFAutoModelForCausalLM,
57
    )
Morgan Funtowicz's avatar
Morgan Funtowicz committed
58
59
60

if is_torch_available():
    import torch
61
    from .modeling_auto import (
62
63
64
65
        AutoModel,
        AutoModelForSequenceClassification,
        AutoModelForQuestionAnswering,
        AutoModelForTokenClassification,
66
        AutoModelForSeq2SeqLM,
67
68
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
69
70
71
72
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
73
        MODEL_FOR_MASKED_LM_MAPPING,
74
    )
Morgan Funtowicz's avatar
Morgan Funtowicz committed
75

76
77
78
79
if TYPE_CHECKING:
    from .modeling_utils import PreTrainedModel
    from .modeling_tf_utils import TFPreTrainedModel

Morgan Funtowicz's avatar
Morgan Funtowicz committed
80

81
82
# Module-level logger (e.g. ``Pipeline.save_pretrained`` reports an invalid save directory through it).
logger = logging.getLogger(__name__)

83

thomwolf's avatar
thomwolf committed
84
def get_framework(model=None):
    """
    Select framework (TensorFlow or PyTorch) to use.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`):
            If both frameworks are installed and :obj:`model` is a model instance, picks the framework matching it
            (TensorFlow when its class name starts with ``"TF"``, PyTorch otherwise). A model name (:obj:`str`) is
            not inspected: like passing no model at all, it defaults to PyTorch when installed, TensorFlow otherwise.

    Returns:
        :obj:`str`: :obj:`"tf"` for TensorFlow or :obj:`"pt"` for PyTorch.

    Raises:
        RuntimeError: If neither TensorFlow nor PyTorch is installed.
    """
    tf_installed = is_tf_available()
    torch_installed = is_torch_available()

    if not tf_installed and not torch_installed:
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )

    if tf_installed and torch_installed and model is not None and not isinstance(model, str):
        # Both frameworks are available and the user supplied a model instance:
        # guess the framework from the model class name (TensorFlow model classes are prefixed with "TF").
        return "tf" if model.__class__.__name__.startswith("TF") else "pt"

    # Only one framework is installed, or there is no model instance to inspect: prefer PyTorch.
    return "pt" if torch_installed else "tf"

108

109
110
class PipelineException(Exception):
    """
    Error raised by a :class:`~transformers.Pipeline` while it is processing a ``__call__``.

    Args:
        task (:obj:`str`): The task of the pipeline.
        model (:obj:`str`): The model used by the pipeline.
        reason (:obj:`str`): The error message to display.
    """

    def __init__(self, task: str, model: str, reason: str):
        # The reason is the exception message (what ``str(exc)`` and ``exc.args`` expose).
        super().__init__(reason)

        # Keep the failing pipeline's context around for whoever handles the error.
        self.task, self.model = task, model


126
127
class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class DefaultArgumentHandler(ArgumentHandler):
    """
    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.

    Inputs are normalized to a flat list: a single string becomes a one-element list, a single list is returned as-is,
    several strings are collected into a list, and several arguments mixing strings and iterables are flattened by
    one level (strings are always kept whole).
    """

    @staticmethod
    def handle_kwargs(kwargs: Dict) -> List:
        """
        Normalize keyword-argument inputs. Only the values matter (the keyword names are ignored); they are handled
        exactly like positional arguments, in insertion order.
        """
        return DefaultArgumentHandler.handle_args(list(kwargs.values()))

    @staticmethod
    def handle_args(args: Sequence[Any]) -> List[str]:
        """
        Normalize positional-argument inputs into a flat list.

        Raises:
            ValueError: If several arguments are given and one of them is neither a string nor an iterable.
        """
        # Only one argument, let's do case by case
        if len(args) == 1:
            if isinstance(args[0], str):
                return [args[0]]
            elif not isinstance(args[0], list):
                return list(args)
            else:
                return args[0]

        # Multiple arguments (x1, x2, ...)
        elif len(args) > 1:
            if all(isinstance(arg, str) for arg in args):
                return list(args)

            # Mixed arguments: flatten one level. Strings are iterable too, so they must be kept whole explicitly,
            # otherwise they would be split into individual characters.
            flattened = []
            for arg in args:
                if isinstance(arg, str):
                    flattened.append(arg)
                elif isinstance(arg, Iterable):
                    flattened.extend(arg)
                else:
                    raise ValueError(
                        "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(arg))
                    )
            return flattened
        else:
            return []

    def __call__(self, *args, **kwargs):
        if len(kwargs) > 0 and len(args) > 0:
            raise ValueError("Pipeline cannot handle mixed args and kwargs")

        if len(kwargs) > 0:
            return DefaultArgumentHandler.handle_kwargs(kwargs)
        else:
            return DefaultArgumentHandler.handle_args(args)
Morgan Funtowicz's avatar
Morgan Funtowicz committed
185
186


187
class PipelineDataFormat:
    """
    Base class for every data format a pipeline can read its inputs from and write its outputs to. The formats
    currently supported are:

    - JSON
    - CSV
    - stdin/stdout (pipe)

    :obj:`PipelineDataFormat` also offers helpers for multi-column data, such as mapping dataset columns onto
    pipeline keyword arguments with the :obj:`dataset_kwarg_1=dataset_column_1` syntax.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(
        self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite: bool = False,
    ):
        self.output_path = output_path
        self.input_path = input_path

        # ``column`` is a comma-separated spec; without one, a single empty column name is used.
        columns = [""] if column is None else column.split(",")
        self.is_multi_columns = len(columns) > 1
        if self.is_multi_columns:
            # Every entry becomes a (kwarg, column) pair; a bare name maps onto itself.
            columns = [tuple(spec.split("=")) if "=" in spec else (spec, spec) for spec in columns]
        self.column = columns

        # Refuse to clobber an existing output file unless explicitly allowed.
        if output_path is not None and not overwrite and exists(abspath(self.output_path)):
            raise OSError("{} already exists on disk".format(self.output_path))

        if input_path is not None and not exists(abspath(self.input_path)):
            raise OSError("{} doesnt exist on disk".format(self.input_path))

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Union[dict, List[dict]]):
        """
        Persist ``data`` using the representation of the concrete
        :class:`~transformers.pipelines.PipelineDataFormat` subclass.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Pickle ``data`` to disk, next to :obj:`output_path` but with a ``.pickle`` extension.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.

        Returns:
            :obj:`str`: Path where the data has been saved.
        """
        # Keep the output path but swap its extension for "pickle".
        stem = os.path.splitext(self.output_path)[0]
        binary_path = os.path.extsep.join((stem, "pickle"))

        with open(binary_path, "wb+") as f_output:
            pickle.dump(data, f_output)

        return binary_path

    @staticmethod
    def from_str(
        format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
    ) -> "PipelineDataFormat":
        """
        Build the :class:`~transformers.pipelines.PipelineDataFormat` subclass matching :obj:`format`.

        Args:
            format (:obj:`str`):
                The format of the desired pipeline. Acceptable values are :obj:`"json"`, :obj:`"csv"` or :obj:`"pipe"`.
            output_path (:obj:`str`, `optional`):
                Where to save the outgoing data.
            input_path (:obj:`str`, `optional`):
                Where to look for the input data.
            column (:obj:`str`, `optional`):
                The column to read.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to overwrite the :obj:`output_path`.

        Returns:
            :class:`~transformers.pipelines.PipelineDataFormat`: The proper data format.

        Raises:
            KeyError: If :obj:`format` is not one of the supported formats.
        """
        if format == "json":
            data_format_class = JsonPipelineDataFormat
        elif format == "csv":
            data_format_class = CsvPipelineDataFormat
        elif format == "pipe":
            data_format_class = PipedPipelineDataFormat
        else:
            raise KeyError("Unknown reader {} (Available reader are json/csv/pipe)".format(format))

        return data_format_class(output_path, input_path, column, overwrite=overwrite)
291
292
293


class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using CSV data format.

    Rows are read with :class:`csv.DictReader`, so the first line of :obj:`input_path` is used as the header that
    names the columns looked up through :obj:`column`.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        """
        Yield one item per CSV row: a dictionary ``{kwarg: value}`` when several columns were configured, otherwise
        the value of the single configured column.
        """
        # newline="" is required by the csv module so quoted fields containing line breaks are parsed correctly.
        with open(self.input_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        The header is taken from the keys of the first dictionary. The output file is always created (or truncated),
        even when :obj:`data` is empty.

        Args:
            data (:obj:`List[dict]`): The data to store.
        """
        # newline="" prevents the csv writer's "\r\n" row terminators from being translated into blank lines.
        with open(self.output_path, "w", newline="") as f:
            if data:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using JSON file format.

    The input file is expected to contain a JSON array whose entries are indexed by the configured column name(s).

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        # Entries are loaded eagerly. Without an input file (output-only usage) there is simply nothing to iterate.
        if input_path is None:
            self._entries = []
        else:
            with open(input_path, "r") as f:
                self._entries = json.load(f)

    def __iter__(self):
        """
        Yield one item per loaded entry: a dictionary ``{kwarg: value}`` when several columns were configured,
        otherwise the value of the single configured column.
        """
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        """
        Save the provided data object in a json file.

        Args:
            data (:obj:`dict`): The data to store.
        """
        with open(self.output_path, "w") as f:
            json.dump(data, f)


Morgan Funtowicz's avatar
Morgan Funtowicz committed
372
373
374
375
376
377
class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the python process. For multi columns data, columns should be separated by a tab
    character.

    If several columns are provided, then each tab-separated line is turned into a dictionary {column_x: value_x}.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __iter__(self):
        """
        Yield one item per line read from standard input, with the line terminator removed.

        A line without a tab is yielded as a string. A tab-separated line is yielded as a dictionary mapping each
        configured keyword argument to its value when several columns were configured, and as a tuple of the values
        otherwise.
        """
        for line in sys.stdin:
            # Drop the line terminator so it does not leak into the (last) value.
            line = line.rstrip("\n")

            # Split for multi-columns
            if "\t" in line:
                values = line.split("\t")
                if self.is_multi_columns:
                    # Dictionary to map arguments: self.column holds (kwarg, column) pairs in declaration order.
                    yield {kwarg: value for (kwarg, _), value in zip(self.column, values)}
                else:
                    # With zero or one configured column, self.column holds plain strings (``[""]`` when no column
                    # was given) that cannot be unpacked into pairs, so hand back the raw values.
                    yield tuple(values)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        """
        Print the data.

        Args:
            data (:obj:`dict`): The data to store.
        """
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Pickle ``data`` next to :obj:`output_path`; piped output cannot carry binary data.

        Raises:
            KeyError: If no :obj:`output_path` was provided.
        """
        if self.output_path is None:
            raise KeyError(
                "When using piped input on pipeline outputting large object requires an output file path. "
                "Please provide such output path through --output argument."
            )

        return super().save_binary(data)

Morgan Funtowicz's avatar
Morgan Funtowicz committed
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435

class _ScikitCompat(ABC):
    """
    Abstract interface giving pipelines a Scikit-learn / Keras style API through ``transform`` and ``predict``.
    Both methods must be implemented before a subclass can be instantiated.
    """

    @abstractmethod
    def transform(self, X):
        """Apply the model to the inputs ``X`` (Scikit-learn ``transform`` convention)."""
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        """Apply the model to the inputs ``X`` (Scikit-learn / Keras ``predict`` convention)."""
        raise NotImplementedError()


Sylvain Gugger's avatar
Sylvain Gugger committed
436
PIPELINE_INIT_ARGS = r"""
    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, a non-negative integer will run
            the model on the associated CUDA device id.
        binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Flag indicating if the output of the pipeline should be stored in a binary format (i.e., pickle) or as raw
            text.
"""
Morgan Funtowicz's avatar
Morgan Funtowicz committed
464

Lysandre Debut's avatar
Lysandre Debut committed
465

Sylvain Gugger's avatar
Sylvain Gugger committed
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations. Pipeline workflow is defined as a sequence of the following
    operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipelines, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'`),
    output large tensor objects as nested lists. In order to avoid dumping such large structures as textual data we
    provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
    pickle format.
    """

    default_input_names = None

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
    ):

        if framework is None:
            framework = get_framework(model)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.modelcard = modelcard
        self.framework = framework
        # TensorFlow keeps the raw integer ordinal (consumed by ``device_placement``); PyTorch gets a
        # ``torch.device`` where any negative ordinal means CPU.
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
        self.binary_output = binary_output
        self._args_parser = args_parser or DefaultArgumentHandler()

        # Special handling: PyTorch models must be moved to the GPU explicitly.
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer (and model card, if any).

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to save. It will be created if it doesn't exist. If it points to an
                existing file, an error is logged and nothing is saved.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
        """
        if self.framework == "tf":
            # Any negative ordinal means CPU, the same rule ``__init__`` applies when building the PyTorch device.
            with tf.device("/CPU:0" if self.device < 0 else "/device:GPU:{}".format(self.device)):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check if the model class is supported by the pipeline.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.

        Raises:
            PipelineException: If the class name of :obj:`self.model` is not among the supported models.
        """
        if not isinstance(supported_models, list):  # Create from a model mapping
            supported_models = [model_class.__name__ for model_class in supported_models.values()]
        if self.model.__class__.__name__ not in supported_models:
            raise PipelineException(
                self.task,
                self.model.base_model_prefix,
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. "
                f"Supported models are {supported_models}",
            )

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments with the pipeline's argument handler, then tokenize them into framework tensors.
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs, add_special_tokens=add_special_tokens, return_tensors=self.framework, padding=padding,
        )

        return inputs

    def __call__(self, *args, **kwargs):
        """
        Tokenize the inputs and run the model on them; see :meth:`_forward` for the returned value.
        """
        inputs = self._parse_and_tokenize(*args, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework specific forward dispatching.

        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array.

        Returns:
            The first element of the model output, as a numpy array, or as the framework tensor when
            :obj:`return_tensors` is :obj:`True`.
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs.data, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    # Bring the result back to CPU so it can be converted with ``.numpy()`` below.
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        else:
            return predictions.numpy()
643
644


Sylvain Gugger's avatar
Sylvain Gugger committed
645
# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here: this pipeline does not accept `binary_output` (it is always
# forced to True below), so its constructor arguments are documented in the class docstring instead.
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    This feature extraction pipeline can currently be loaded from :func:`~transformers.pipeline` using the task
    identifier: :obj:`"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models on
    `huggingface.co/models <https://huggingface.co/models>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model used by the pipeline to make predictions. It must inherit from
            :class:`~transformers.PreTrainedModel` for PyTorch or :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer used by the pipeline to encode data for the model. It must inherit from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, the one currently installed is used. If none is specified and both
            frameworks are installed, the framework of the :obj:`model` is used, or PyTorch if no model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. -1 runs the model on CPU; a non-negative integer runs it on the CUDA
            device with that id.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        task: str = "",
    ):
        # Raw features are the product of this pipeline, so binary output is always enabled.
        super().__init__(
            task=task,
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            device=device,
            args_parser=args_parser,
            binary_output=True,
        )

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of :obj:`float`: The features computed by the model, converted from the array returned by
            the base pipeline with ``tolist()``.
        """
        features = super().__call__(*args, **kwargs)
        return features.tolist()
Morgan Funtowicz's avatar
Morgan Funtowicz committed
715
716


Sylvain Gugger's avatar
Sylvain Gugger committed
717
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any :obj:`ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    This language generation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2).
    See the list of available community models on
    `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
    """

    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
    (except for Alexei and Maria) are discovered.
    The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
    remainder of the story. 1883 Western Siberia,
    a young Grigori Rasputin is asked by his father and a group of men to perform magic.
    Rasputin has a vision and denounces one of the men as a horse thief. Although his
    father initially slaps him for making such an accusation, Rasputin watches as the
    man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
    with people, even a bishop, begging for his blessing. """

    ALLOWED_MODELS = [
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "ReformerModelWithLMHead",
        "GPT2LMHeadModel",
        "OpenAIGPTLMHeadModel",
        "CTRLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
        "TFGPT2LMHeadModel",
        "TFOpenAIGPTLMHeadModel",
        "TFCTRLLMHeadModel",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(self.ALLOWED_MODELS)

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize them.

        For the PyTorch Transformer-XL model, the tokenizer is additionally called with
        ``add_space_before_punct_symbol=True``.
        """
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            **tokenizer_kwargs,
        )

        return inputs

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the predicted token ids (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__). For XLNet and
                Transformer-XL models, ``max_length`` and ``min_length`` (when given) are extended by the token length
                of the padding text prepended to every prompt.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`List[int]`, present when ``return_tensors=True``) -- The token ids of
              the generated text.
        """
        text_inputs = self._args_parser(*args)

        # For XLNet and Transformer-XL we add an article to the prompt to give more state to the model.
        use_padding_text = self.model.__class__.__name__ in [
            "XLNetLMHeadModel",
            "TransfoXLLMHeadModel",
            "TFXLNetLMHeadModel",
            "TFTransfoXLLMHeadModel",
        ]
        if use_padding_text:
            padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
            padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
            padding_length = padding["input_ids"].shape[-1]
            # The padding counts towards the generated length, so max_length/min_length need adjusting. This is
            # done exactly once: adjusting inside the per-prompt loop would grow the limits with every prompt.
            for length_key in ("max_length", "min_length"):
                if generate_kwargs.get(length_key) is not None:
                    generate_kwargs[length_key] += padding_length

        results = []
        for prompt_text in text_inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                if use_padding_text:
                    inputs = self._parse_and_tokenize(
                        padding_text + prompt_text, padding=False, add_special_tokens=False
                    )
                else:
                    inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)

                # set input_ids to None to allow empty prompt
                if inputs["input_ids"].shape[-1] == 0:
                    inputs["input_ids"] = None
                    inputs["attention_mask"] = None

                if self.framework == "pt" and inputs["input_ids"] is not None:
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids is None or input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."

                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL

            if return_text:
                # Character length of the decoded model input (padding text, if any, plus the prompt). It is cut from
                # every decoded output and replaced by the raw prompt text, which removes the PADDING prompt.
                if input_ids is None:
                    prompt_length = 0
                else:
                    prompt_length = len(
                        self.tokenizer.decode(
                            input_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )

            result = []
            for generated_sequence in output_sequences:
                if self.framework == "pt" and generated_sequence is not None:
                    generated_sequence = generated_sequence.cpu()
                generated_sequence = generated_sequence.numpy().tolist()

                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                    record["generated_text"] = prompt_text + text[prompt_length:]

                result.append(record)
            results += [result]

        if len(results) == 1:
            return results[0]

        return results


Sylvain Gugger's avatar
Sylvain Gugger committed
899
900
901
902
903
904
905
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        return_all_scores (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to return all prediction scores or just the one of the predicted class.
    """,
)
class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using any :obj:`ModelForSequenceClassification`. See the
    `sequence classification examples <../task_summary.html#sequence-classification>`__ for more information.

    This text classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"sentiment-analysis"` (for classifying sequences according to positive or negative
    sentiments).

    The models that this pipeline can use are models that have been fine-tuned on a sequence classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=text-classification>`__.
    """

    def __init__(self, return_all_scores: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
        )

        self.return_all_scores = return_all_scores

    def __call__(self, *args, **kwargs):
        """
        Classify the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) to classify.

        Return:
            A list of :obj:`dict` (one per input text) or, if ``self.return_all_scores=True``, a list of list of
            :obj:`dict` (one list per input text, one dictionary per label). Each dictionary has the following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.
        """
        outputs = super().__call__(*args, **kwargs)
        # Softmax over the last (label) axis. Subtracting the per-row maximum leaves the result unchanged but keeps
        # np.exp from overflowing to inf (and yielding nan scores) on large logits.
        exp_outputs = np.exp(outputs - outputs.max(-1, keepdims=True))
        scores = exp_outputs / exp_outputs.sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(item)]
                for item in scores
            ]
        else:
            return [
                {"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
            ]
Morgan Funtowicz's avatar
Morgan Funtowicz committed
959
960


961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot for text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        """
        Normalize the candidate labels.

        A string is treated as comma-separated labels: each entry is stripped of surrounding whitespace, and empty
        entries (e.g. produced by a trailing or doubled comma) are dropped. Any other value is returned unchanged.
        """
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",") if label.strip()]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        """
        Build the premise/hypothesis pairs fed to the NLI model.

        Args:
            sequences (:obj:`str` or :obj:`List[str]`): The sequence(s) to classify, used as premises.
            labels (:obj:`str` or :obj:`List[str]`): Candidate labels, as a list or a comma-separated string.
            hypothesis_template (:obj:`str`): Template into which each label is inserted with ``str.format``.

        Returns:
            :obj:`List[List[str]]`: One ``[sequence, hypothesis]`` pair per (sequence, label) combination, ordered by
            sequence first, then by label.

        Raises:
            ValueError: If there is no label or no sequence, or if formatting the template with a label leaves it
            unchanged (i.e. it has no placeholder for the label).
        """
        # Parse first so the checks below operate on real labels, not on the characters of a raw string.
        labels = self._parse_labels(labels)
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]

        sequence_pairs = []
        for sequence in sequences:
            sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])

        return sequence_pairs


Sylvain Gugger's avatar
Sylvain Gugger committed
994
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a :obj:`ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the
    candidate label being valid. Any NLI model can be used as long as the first output logit corresponds to
    `contradiction` and the last to `entailment`.

    This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?search=nli>`__.
    """

    def __init__(self, args_parser=None, *args, **kwargs):
        # Create the handler per pipeline: a default-argument instance would be built once at class definition time
        # and shared by every pipeline.
        if args_parser is None:
            args_parser = ZeroShotClassificationArgumentHandler()
        super().__init__(*args, args_parser=args_parser, **kwargs)

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation="only_first",
        )

        return inputs

    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
        """
        Classify the sequence(s) given as inputs.

        Args:
            sequences (:obj:`str` or :obj:`List[str]`):
                The sequence(s) to classify, will be truncated if the model input is too large.
            candidate_labels (:obj:`str` or :obj:`List[str]`):
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`):
                The template used to turn each label into an NLI-style hypothesis. This template must include a {}
                or similar syntax for the candidate label to be inserted into the template. For example, the default
                template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed
                into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The
                default template works well in many cases, but it may be worthwhile to experiment with different
                templates depending on the task setting.
            multi_class (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized
                such that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are
                considered independent and probabilities are normalized for each candidate by doing a softmax of
                the entailment score vs. the contradiction score. Forced to :obj:`True` when there is a single
                candidate label.

        Return:
            A :obj:`dict` (single input sequence) or a list of :obj:`dict` (several sequences). Each result comes as a
            dictionary with the following keys:

            - **sequence** (:obj:`str`) -- The sequence for which this is the output.
            - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
            - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
        """
        outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
        num_sequences = 1 if isinstance(sequences, str) else len(sequences)
        # Same parsing as the argument handler, so the reshape below matches the pairs that were built.
        candidate_labels = self._args_parser._parse_labels(candidate_labels)
        reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

        if len(candidate_labels) == 1:
            multi_class = True

        # Both softmaxes subtract the row maximum first: mathematically identical, but np.exp cannot overflow.
        if not multi_class:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., -1]
            exp_logits = np.exp(entail_logits - entail_logits.max(-1, keepdims=True))
            scores = exp_logits / exp_logits.sum(-1, keepdims=True)
        else:
            # softmax over the entailment vs. contradiction dim for each label independently
            entail_contr_logits = reshaped_outputs[..., [0, -1]]
            exp_logits = np.exp(entail_contr_logits - entail_contr_logits.max(-1, keepdims=True))
            scores = (exp_logits / exp_logits.sum(-1, keepdims=True))[..., 1]

        result = []
        for iseq in range(num_sequences):
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences if isinstance(sequences, str) else sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq][top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result


Sylvain Gugger's avatar
Sylvain Gugger committed
1096
1097
1098
1099
1100
1101
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        topk (:obj:`int`, defaults to 5): The number of predictions to return.
    """,
)
class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using any :obj:`ModelWithLMHead`. See the
    `masked language modeling examples <../task_summary.html#masked-language-modeling>`__ for more information.

    This mask filling pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"fill-mask"`.

    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
    which includes the bi-directional models in the library.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=lm-head>`__.

    .. note::

        This pipeline only works for inputs with exactly one token masked.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        topk=5,
        task: str = "",
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)

        self.topk = topk

    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
        """
        Raise a :class:`PipelineException` unless ``masked_index`` holds exactly one element, i.e. the sample contains
        exactly one mask token.
        """
        numel = np.prod(masked_index.shape)
        if numel > 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
            )
        elif numel < 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
            )

    def __call__(self, *args, **kwargs):
        """
        Fill the masked token in the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) with masked tokens.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
            following keys:

            - **sequence** (:obj:`str`) -- The corresponding input with the mask token prediction.
            - **score** (:obj:`float`) -- The corresponding probability.
            - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
            - **token_str** (:obj:`str`) -- The predicted token (to replace the masked one).
        """
        inputs = self._parse_and_tokenize(*args, **kwargs)
        outputs = self._forward(inputs, return_tensors=True)

        results = []
        batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

        for i in range(batch_size):
            input_ids = inputs["input_ids"][i]
            result = []

            if self.framework == "tf":
                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index)

                logits = outputs[i, masked_index.item(), :]
                probs = tf.nn.softmax(logits)
                topk = tf.math.top_k(probs, k=self.topk)
                values, predictions = topk.values.numpy(), topk.indices.numpy()
            else:
                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index.numpy())

                logits = outputs[i, masked_index.item(), :]
                probs = logits.softmax(dim=0)
                values, predictions = probs.topk(self.topk)

            for v, p in zip(values.tolist(), predictions.tolist()):
                # Work on a copy: for PyTorch, Tensor.numpy() shares storage with the tensor, so writing the
                # prediction in place would otherwise modify the tokenized inputs themselves.
                tokens = input_ids.numpy().copy()
                tokens[masked_index] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                result.append(
                    {
                        "sequence": self.tokenizer.decode(tokens),
                        "score": v,
                        "token": p,
                        "token_str": self.tokenizer.convert_ids_to_tokens(p),
                    }
                )

            # Append
            results += [result]

        if len(results) == 1:
            return results[0]
        return results


Sylvain Gugger's avatar
Sylvain Gugger committed
1229
1230
1231
1232
1233
1234
1235
1236
1237
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
            A list of labels to ignore.
        grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
    """,
)
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the
    `named entity recognition examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=token-classification>`__.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        ignore_labels: Optional[List[str]] = None,
        task: str = "",
        grouped_entities: bool = False,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        # A fresh list per instance: a mutable default argument would be shared (and mutable) across all pipelines.
        self.ignore_labels = ["O"] if ignore_labels is None else ignore_labels
        self.grouped_entities = grouped_entities

    def __call__(self, *args, **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with
            :obj:`grouped_entities=True`) with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word (named ``entity_group`` when
              ``self.grouped_entities=True``).
            - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
              corresponding token in the sentence.
        """
        inputs = self._args_parser(*args, **kwargs)
        answers = []
        for sentence in inputs:

            # Manage correct placement of the tensors
            with self.device_placement():

                tokens = self.tokenizer(
                    sentence, return_attention_mask=False, return_tensors=self.framework, truncation=True,
                )

                # Forward: keep the logits of the first (only) sequence in the batch
                if self.framework == "tf":
                    logits = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        logits = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            # Softmax over the last axis. Subtracting the row max first prevents np.exp from overflowing on large
            # logits; the resulting probabilities are mathematically unchanged.
            shifted_exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
            score = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)

            id2label = self.model.config.id2label
            entities = []
            for idx, label_idx in enumerate(labels_idx):
                label = id2label[label_idx]
                # Filter out labels listed in `self.ignore_labels`
                if label in self.ignore_labels:
                    continue
                entities.append(
                    {
                        "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
                        "score": score[idx][label_idx].item(),
                        "entity": label,
                        "index": idx,
                    }
                )

            if self.grouped_entities:
                # Append grouped entities
                answers += [self.group_entities(entities)]
            else:
                # Append ungrouped entities
                answers += [entities]

        if len(answers) == 1:
            return answers[0]

        return answers

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline (must be non-empty).

        Returns:
            :obj:`dict`: A dictionary with keys ``entity_group`` (the label of the first token of the group),
            ``score`` (the mean score of the tokens) and ``word`` (the tokens joined back into a string).
        """
        # The group takes the label of its first token
        group_label = entities[0]["entity"]
        mean_score = np.mean([sub_entity["score"] for sub_entity in entities])
        tokens = [sub_entity["word"] for sub_entity in entities]

        return {
            "entity_group": group_label,
            "score": mean_score,
            "word": self.tokenizer.convert_tokens_to_string(tokens),
        }

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.

        Returns:
            :obj:`List[dict]`: One dictionary per group, as built by :meth:`group_sub_entities` (empty when
            :obj:`entities` is empty).
        """
        entity_groups = []
        entity_group_disagg = []

        if not entities:
            return entity_groups

        last_idx = entities[-1]["index"]

        for entity in entities:
            is_last_idx = entity["index"] == last_idx
            if not entity_group_disagg:
                entity_group_disagg += [entity]
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
                continue

            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated
            # entity group. The split is meant to account for the "B" and "I" prefixes.
            if (
                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
                and entity["index"] == entity_group_disagg[-1]["index"] + 1
            ):
                entity_group_disagg += [entity]
                # Group the entities at the last entity
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
            else:
                entity_groups += [self.group_sub_entities(entity_group_disagg)]
                entity_group_disagg = [entity]
                # If it's the last entity, add it to the entity groups
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]

        return entity_groups

Morgan Funtowicz's avatar
Morgan Funtowicz committed
1422

1423
# Alias of TokenClassificationPipeline, so the same pipeline class can also be referenced as ``NerPipeline``.
NerPipeline = TokenClassificationPipeline
1424
1425


1426
1427
1428
class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal :class:`~transformers.SquadExample`.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a :class:`~transformers.SquadExample`
    from the command-line supplied arguments.
    """

    def __call__(self, *args, **kwargs):
        """
        Normalize the pipeline inputs into a list of :class:`~transformers.SquadExample`.

        Accepted forms:

        - positional arguments, ``X=`` or ``data=``: a :class:`~transformers.SquadExample`, a :obj:`dict` with
          ``question`` and ``context`` keys, or an iterable of those;
        - ``question=`` and ``context=``: each a :obj:`str` or a list of :obj:`str`. A single question (resp.
          context) is paired with every context (resp. question); otherwise both must have the same length.

        Returns:
            :obj:`List[SquadExample]`: One example per (question, context) pair.

        Raises:
            KeyError: If a dictionary input lacks the ``question`` or ``context`` key.
            ValueError: If an item has an unsupported type, if the question and context lists have different
            lengths, or if neither ``X``/``data`` nor ``question``/``context`` were provided.
        """
        # Positional args are handled the same way as X and data, so forward them to avoid duplicating the logic
        if args:
            kwargs["X"] = args[0] if len(args) == 1 else list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if "X" in kwargs or "data" in kwargs:
            inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]

            if isinstance(inputs, (dict, SquadExample)):
                # A single sample: wrap it rather than iterating over it (a SquadExample is not iterable)
                inputs = [inputs]
            else:
                # Copy to avoid overriding arguments
                inputs = list(inputs)

            for i, item in enumerate(inputs):
                if isinstance(item, dict):
                    if any(k not in item for k in ["question", "context"]):
                        raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

                    inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif not isinstance(item, SquadExample):
                    raise ValueError(
                        "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
                            "X" if "X" in kwargs else "data"
                        )
                    )

        # Tabular input
        elif "question" in kwargs and "context" in kwargs:
            questions = kwargs["question"]
            contexts = kwargs["context"]
            questions = [questions] if isinstance(questions, str) else list(questions)
            contexts = [contexts] if isinstance(contexts, str) else list(contexts)

            # Pair a single question with every context (and vice versa) instead of silently dropping samples,
            # and refuse lists of different lengths instead of silently truncating them.
            if len(questions) == 1 and len(contexts) > 1:
                questions = questions * len(contexts)
            elif len(contexts) == 1 and len(questions) > 1:
                contexts = contexts * len(questions)
            elif len(questions) != len(contexts):
                raise ValueError(
                    "Got {} question(s) and {} context(s); they must have the same number of elements".format(
                        len(questions), len(contexts)
                    )
                )

            inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(questions, contexts)]
        else:
            raise ValueError("Unknown arguments {}".format(kwargs))

        return inputs


Sylvain Gugger's avatar
Sylvain Gugger committed
1488
@add_end_docstrings(PIPELINE_INIT_ARGS)
class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using any :obj:`ModelForQuestionAnswering`. See the
    `question answering examples <../task_summary.html#question-answering>`__ for more information.

    This question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=question-answering>`__.
    """

    default_input_names = "question,context"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=QuestionAnsweringArgumentHandler(),
            device=device,
            task=task,
            **kwargs,
        )

        self.check_model_type(
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
        )

    @staticmethod
    def create_sample(
        question: Union[str, List[str]], context: Union[str, List[str]]
    ) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the :class:`~transformers.SquadExample` internally.
        This helper method encapsulates all the logic for converting question(s) and context(s) to
        :class:`~transformers.SquadExample`.

        We currently support extractive question answering.

        Arguments:
            question (:obj:`str` or :obj:`List[str]`): The question(s) asked.
            context (:obj:`str` or :obj:`List[str]`): The context(s) in which we will look for the answer.

        Returns:
            One or a list of :class:`~transformers.SquadExample`: The corresponding
            :class:`~transformers.SquadExample` grouping question and context.
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """
        Numerically stable softmax over the last axis: the max is subtracted before exponentiating so that large
        logits cannot overflow :func:`np.exp`.
        """
        shifted_exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            args (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`):
                One or several :class:`~transformers.SquadExample` containing the question and context.
            X (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            data (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            question (:obj:`str` or :obj:`List[str]`):
                One or several question(s) (must be used in conjunction with the :obj:`context` argument).
            context (:obj:`str` or :obj:`List[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                :obj:`question` argument).
            topk (:obj:`int`, `optional`, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood).
            doc_stride (:obj:`int`, `optional`, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split in several chunks
                with some overlap. This argument controls the size of that overlap.
            max_answer_len (:obj:`int`, `optional`, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (:obj:`int`, `optional`, defaults to 384):
                The maximum length of the total sentence (context + question) after tokenization. The context will be
                split in several chunks (using :obj:`doc_stride`) if needed.
            max_question_len (:obj:`int`, `optional`, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not we accept impossible as an answer.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **score** (:obj:`float`) -- The probability associated to the answer.
            - **start** (:obj:`int`) -- The character start index of the answer in the context.
            - **end** (:obj:`int`) -- The character end index of the answer in the context.
            - **answer** (:obj:`str`) -- The answer to the question.

        Raises:
            ValueError: If :obj:`topk` or :obj:`max_answer_len` is smaller than 1.
        """
        # Set defaults values
        kwargs.setdefault("topk", 1)
        kwargs.setdefault("doc_stride", 128)
        kwargs.setdefault("max_answer_len", 15)
        kwargs.setdefault("max_seq_len", 384)
        kwargs.setdefault("max_question_len", 64)
        kwargs.setdefault("handle_impossible_answer", False)

        if kwargs["topk"] < 1:
            raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))

        if kwargs["max_answer_len"] < 1:
            raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

        # Convert inputs to features
        examples = self._args_parser(*args, **kwargs)
        features_list = [
            squad_convert_examples_to_features(
                examples=[example],
                tokenizer=self.tokenizer,
                max_seq_length=kwargs["max_seq_len"],
                doc_stride=kwargs["doc_stride"],
                max_query_length=kwargs["max_question_len"],
                padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
                is_training=False,
                tqdm_enabled=False,
            )
            for example in examples
        ]
        all_answers = []
        for features, example in zip(features_list, examples):
            model_input_names = self.tokenizer.model_input_names + ["input_ids"]
            fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

            # Manage tensor allocation on correct device
            with self.device_placement():
                if self.framework == "tf":
                    fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                    start, end = self.model(fw_args)[:2]
                    start, end = start.numpy(), end.numpy()
                else:
                    with torch.no_grad():
                        # Retrieve the score for the context tokens only (removing question tokens)
                        fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
                        start, end = self.model(**fw_args)[:2]
                        start, end = start.cpu().numpy(), end.cpu().numpy()

            min_null_score = 1000000  # large and positive
            answers = []
            for (feature, start_, end_) in zip(features, start, end):
                # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
                undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

                # Generate mask
                undesired_tokens_mask = undesired_tokens == 0.0

                # Make sure non-context indexes in the tensor cannot contribute to the softmax
                start_ = np.where(undesired_tokens_mask, -10000.0, start_)
                end_ = np.where(undesired_tokens_mask, -10000.0, end_)

                # Normalize logits and spans to retrieve the answer
                start_ = self._softmax(start_)
                end_ = self._softmax(end_)

                if kwargs["handle_impossible_answer"]:
                    min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

                # Mask CLS
                start_[0] = end_[0] = 0.0

                starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
                char_to_word = np.array(example.char_to_word_offset)

                # Convert the answer (tokens) back to the original text
                answers += [
                    {
                        "score": score.item(),
                        "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                        "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
                        "answer": " ".join(
                            example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
                        ),
                    }
                    for s, e, score in zip(starts, ends, scores)
                ]

            if kwargs["handle_impossible_answer"]:
                answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})

            answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
            all_answers += answers

        if len(all_answers) == 1:
            return all_answers[0]
        return all_answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Take the output of any :obj:`ModelForQuestionAnswering` and will generate probabilities for each span to be
        the actual answer.

        In addition, it filters out some unwanted/impossible cases like answer len being greater than
        max_answer_len or answer end position being before the starting position.
        The method supports outputting the k-best answers through the topk argument.

        Args:
            start (:obj:`np.ndarray`): Individual start probabilities for each token.
            end (:obj:`np.ndarray`): Individual end probabilities for each token.
            topk (:obj:`int`): Indicates how many possible answer span(s) to extract from the model output.
            max_answer_len (:obj:`int`): Maximum size of the answer to extract from the model's output.

        Returns:
            :obj:`Tuple`: ``(starts, ends, scores)`` for the best spans, sorted by decreasing score.
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidate with end < start and end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) <= topk:
            # np.argpartition requires kth < len(scores_flat); when every candidate is wanted, just sort them all
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]

    def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to actual word in
        the initial context.

        Args:
            text (:obj:`str`): The actual context to extract the answer from.
            start (:obj:`int`): The answer starting token index.
            end (:obj:`int`): The answer end token index.

        Returns:
            Dictionary like :obj:`{'answer': str, 'start': int, 'end': int}`
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for word in text.split(" "):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {
            "answer": " ".join(words),
            "start": max(0, char_start_idx),
            "end": min(len(text), char_end_idx),
        }
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1775
1776


Sylvain Gugger's avatar
Sylvain Gugger committed
1777
@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Pipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"summarization"`.

    The models that this pipeline can use are models that have been fine-tuned on a summarization task,
    which is currently, '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.

    Usage::

        # use bart in pytorch
        summarizer = pipeline("summarization")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

        # use t5 in tf
        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
    """

    def __init__(self, *args, **kwargs):
        # This pipeline is always registered under the "summarization" task, whatever the caller passed
        kwargs["task"] = "summarization"
        super().__init__(*args, **kwargs)

        expected_mapping = (
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )
        self.check_model_type(expected_mapping)

    def __call__(
        self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                One or several articles (or one list of articles) to summarize. Only the first positional argument
                is used.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list of :obj:`dict`, one per sequence returned by the model's ``generate`` method. Each dictionary has
            the following keys:

            - **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
              input.
            - **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the summary.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        assert len(documents) > 0, "Please provide a document to summarize"

        model_class_name = self.model.__class__.__name__
        if self.framework == "tf" and "BartForConditionalGeneration" in model_class_name:
            raise NotImplementedError(
                "Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
            )

        # Some models (e.g. T5) expect a task prefix in front of every document
        prefix = self.model.config.prefix
        if prefix is None:
            prefix = ""

        first_document = documents[0]
        if isinstance(first_document, list):
            # Batched input: sequences of different lengths have to be padded together
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            documents = ([prefix + document for document in first_document],)
            padding = True
        elif isinstance(first_document, str):
            documents = (prefix + first_document,)
            padding = False
        else:
            raise ValueError(
                " `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
                    first_document
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*documents, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]
            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            min_length = generate_kwargs.get("min_length", self.model.config.min_length)
            if input_length < min_length // 2:
                logger.warning(
                    "Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
                        min_length, input_length
                    )
                )

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length < max_length:
                logger.warning(
                    "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
                        max_length, input_length
                    )
                )

            summaries = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
            )

            def to_record(summary):
                # Token ids are inserted before the decoded text so the key order matches the documented output
                record = {}
                if return_tensors:
                    record["summary_token_ids"] = summary
                if return_text:
                    record["summary_text"] = self.tokenizer.decode(
                        summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                return record

            return [to_record(summary) for summary in summaries]


Sylvain Gugger's avatar
Sylvain Gugger committed
1907
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TranslationPipeline(Pipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=translation>`__.

    Usage::

        en_fr_translator = pipeline("translation_en_to_fr")
        en_fr_translator("How old are you?")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Texts to be translated. Only the first positional argument is used.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list of :obj:`dict`, one per sequence returned by the model's ``generate`` method. Each dictionary has
            the following keys:

            - **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.

        Raises:
            :obj:`ValueError`: If the first argument is neither a :obj:`str` nor a :obj:`list`.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        assert len(args) > 0, "Please provide a text to translate"

        # Some models (e.g. T5) expect a task prefix such as "translate English to French: " in front of the text
        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(args[0], list):
            # Batched input: sequences of different lengths have to be padded together
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            args = ([prefix + text for text in args[0]],)
            padding = True

        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                "`args[0]`: {} has the wrong format. It should be either of type `str` or type `list`".format(
                    args[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length > 0.9 * max_length:
                logger.warning(
                    "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
                        input_length, max_length
                    )
                )

            translations = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
            )
            results = []
            for translation in translations:
                # Token ids are inserted before the decoded text so the key order matches the documented output
                record = {}
                if return_tensors:
                    record["translation_token_ids"] = translation
                if return_text:
                    record["translation_text"] = self.tokenizer.decode(
                        translation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


2015
2016
2017
class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    :class:`~transformers.ConversationalPipeline`. The conversation contains a number of utility functions to manage
    the addition of new user input and generated model responses. A conversation needs to contain an unprocessed user
    input before being passed to the :class:`~transformers.ConversationalPipeline`. This user input is either created
    when the class is instantiated, or by calling :obj:`conversation.add_user_input("input")` after a conversation
    turn.

    Arguments:
        text (:obj:`str`, `optional`):
            The initial user input to start the conversation. If not provided, a user input needs to be provided
            manually using the :meth:`~transformers.Conversation.add_user_input` method before the conversation can
            begin.
        conversation_id (:obj:`uuid.UUID`, `optional`):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big lebowski.")

        conversation.add_user_input("Is it good?")
    """

    def __init__(self, text: Optional[str] = None, conversation_id: Optional[UUID] = None):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        self.uuid: UUID = conversation_id
        # User inputs already moved out of `new_user_input` by `mark_processed`, oldest first
        self.past_user_inputs: List[str] = []
        # Responses added through `append_response` (or by the pipeline), oldest first
        self.generated_responses: List[str] = []
        # Token ids of the conversation so far, as stored by `set_history`
        self.history: List[int] = []
        # Pending user input for the next round, or None when there is none
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This populates the internal :obj:`new_user_input`
        field.

        If an unprocessed input is already pending, a warning is logged and the new text either replaces it
        (``overwrite=True``) or is ignored (``overwrite=False``).

        Args:
            text (:obj:`str`): The user input for the next conversation round.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not existing and unprocessed user input should be overwritten when this function is called.
        """
        if self.new_user_input:
            if overwrite:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
                        self.new_user_input, text
                    )
                )
                self.new_user_input = text
            else:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
                    "Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
                )
        else:
            self.new_user_input = text

    def mark_processed(self):
        """
        Mark the conversation as processed (moves the content of :obj:`new_user_input` to :obj:`past_user_inputs`) and
        empties the :obj:`new_user_input` field.
        """
        if self.new_user_input:
            self.past_user_inputs.append(self.new_user_input)
        self.new_user_input = None

    def append_response(self, response: str):
        """
        Append a response to the list of generated responses.

        Args:
            response (:obj:`str`): The model generated response.
        """
        self.generated_responses.append(response)

    def set_history(self, history: List[int]):
        """
        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
        The history is used by the model to generate responses based on the previous conversation turns.

        Args:
            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
        """
        self.history = history

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Past user inputs are paired with generated responses in order (unpaired entries are not shown), followed by the
        pending user input if there is one.

        Return:
            :obj:`str`: The conversation id line followed by one ``user >>`` / ``bot >>`` line per message.

        Example::

            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
            user >> Going to the movies tonight - any suggestions?
            bot >> The Big Lebowski
        """
        output = "Conversation id: {} \n".format(self.uuid)
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            output += "user >> {} \n".format(user_input)
            output += "bot >> {} \n".format(generated_response)
        if self.new_user_input is not None:
            output += "user >> {} \n".format(self.new_user_input)
        return output


Sylvain Gugger's avatar
Sylvain Gugger committed
2130
2131
2132
2133
2134
2135
2136
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        min_length_for_response (:obj:`int`, `optional`, defaults to 32):
            The minimum length (in number of tokens) for a response.
    """,
)
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    This conversational pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: `'microsoft/DialoGPT-small'`, `'microsoft/DialoGPT-medium'`, `'microsoft/DialoGPT-large'`.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
        # Fall back on the EOS token for padding when the tokenizer has no dedicated pad token
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.eos_token_id
        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for. Each of them must contain a new (unprocessed) user input.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
            updated generated responses for those containing a new user input. A single conversation is returned
            whenever exactly one conversation was processed.

        Raises:
            :obj:`ValueError`: If the input is not a :class:`~transformers.Conversation` or a list of them, if the list
            is empty, or if a conversation does not contain a new user input.
        """
        # Input validation: normalize to a list first so a single conversation is checked like a batch
        if isinstance(conversations, Conversation):
            conversations = [conversations]
        elif not isinstance(conversations, list):
            raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
        if not conversations:
            raise ValueError("DialoguePipeline expects at least one Conversation as an input")
        for conversation in conversations:
            assert isinstance(
                conversation, Conversation
            ), "DialoguePipeline expects a Conversation or list of Conversations as an input"
            if conversation.new_user_input is None:
                raise ValueError(
                    "Conversation with UUID {} does not contain new user input to process. "
                    "Add user inputs with the conversation's `add_user_input` method".format(conversation.uuid)
                )
        assert (
            self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
        ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"

        with self.device_placement():

            inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
            histories = [conversation.history for conversation in conversations]
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            inputs = self._concat_inputs_history(inputs, histories, max_length)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            if input_length > 0.9 * max_length:
                logger.warning(
                    "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                    "You might consider trimming the early phase of the conversation".format(input_length, max_length)
                )
            generated_responses = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
            )

            cleaned_history = self._clean_padding_history(generated_responses)
            output = []
            for conversation_index, conversation in enumerate(conversations):
                # Every row starts with the same padded prompt of `input_length` tokens, so the response is read from
                # the raw generated row. The cleaned history cannot be sliced at `input_length`: collapsing runs of
                # padding shortens the prompt part of padded rows and would cut the start of the response.
                response_ids = self._tensor_to_token_ids(generated_responses[conversation_index])[input_length:]
                conversation.mark_processed()
                conversation.generated_responses.append(
                    self.tokenizer.decode(
                        response_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                )
                conversation.set_history(cleaned_history[conversation_index])
                output.append(conversation)
            if len(output) == 1:
                return output[0]
            return output

    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for input in inputs:
            input.append(self.tokenizer.eos_token_id)
        return inputs

    def _tensor_to_token_ids(self, sequence) -> List[int]:
        """
        Converts one row of the tensor returned by ``generate`` (PyTorch or TensorFlow) into a list of Python ints.
        """
        if self.framework == "pt":
            return [token.item() for token in sequence]
        return [int(token.numpy()) for token in sequence]

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
        an input:
            - at the end of the concatenated history and new user input, so that all input to the model have the same
                length
            - at the end of the generated response, as some responses will be longer than others
        This method collapses every run of padding tokens into a single one so that the history for each conversation
        is not impacted by the batching process.
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in self._tensor_to_token_ids(sequence):
                is_pad = token == self.pad_token_id
                # Keep only the first token of each run of padding tokens
                if is_pad and is_previous_pad:
                    continue
                is_previous_pad = is_pad
                sequence_tokens.append(token)
            outputs.append(sequence_tokens)
        return outputs

    def _truncate_oldest_turns(self, input_ids: List[int], max_input_length: int) -> List[int]:
        """
        Drops the oldest turns of a conversation (everything up to and including the first EOS token) until
        ``input_ids`` holds at most ``max_input_length`` tokens. Stops early once the first EOS token is the last
        token, so the newest turn is kept even when it is longer than ``max_input_length`` on its own.
        """
        eos_token_id = self.tokenizer.eos_token_id
        while len(input_ids) > max_input_length:
            try:
                first_eos_index = input_ids.index(eos_token_id)
            except ValueError:
                # No turn separator left: nothing can be dropped without cutting a turn in half
                break
            if first_eos_index >= len(input_ids) - 1:
                # The first EOS closes the newest turn, which must be kept
                break
            input_ids = input_ids[first_eos_index + 1 :]
        return input_ids

    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context.

        Inputs longer than ``max_length - min_length_for_response`` tokens lose their oldest turns first. All inputs
        are then right-padded with :obj:`pad_token_id` to a common length, and the attention mask is 1 on real tokens
        and 0 on padding.
        """
        max_input_length = max_length - self.min_length_for_response
        outputs = []
        for new_input, history in zip(inputs, histories):
            if history is not None:
                new_input = history + new_input
            outputs.append(self._truncate_oldest_turns(new_input, max_input_length))
        max_len = max(len(item) for item in outputs)
        input_ids = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
        attention_mask = [[1] * len(output) + [0] * (max_len - len(output)) for output in outputs]
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}, tensor_type=self.framework)


2326
# Register all the supported tasks here
Morgan Funtowicz's avatar
Morgan Funtowicz committed
2327
def _same_checkpoint(checkpoint: str) -> Dict[str, str]:
    """Return a new ``{"pt": checkpoint, "tf": checkpoint}`` mapping (one checkpoint shared by both frameworks)."""
    return {"pt": checkpoint, "tf": checkpoint}


# Task name -> pipeline class ("impl"), auto model class per framework ("tf" / "pt", None when the framework is not
# installed) and default checkpoints per framework ("default").
SUPPORTED_TASKS = {
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": TFAutoModel if is_tf_available() else None,
        "pt": AutoModel if is_torch_available() else None,
        "default": {"model": _same_checkpoint("distilbert-base-cased")},
    },
    "sentiment-analysis": {
        "impl": TextClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {"model": _same_checkpoint("distilbert-base-uncased-finetuned-sst-2-english")},
    },
    "ner": {
        "impl": TokenClassificationPipeline,
        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
        "pt": AutoModelForTokenClassification if is_torch_available() else None,
        "default": {"model": _same_checkpoint("dbmdz/bert-large-cased-finetuned-conll03-english")},
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
        "default": {"model": _same_checkpoint("distilbert-base-cased-distilled-squad")},
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForMaskedLM if is_torch_available() else None,
        "default": {"model": _same_checkpoint("distilroberta-base")},
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        # Different defaults per framework
        "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
    },
    "translation_en_to_fr": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": _same_checkpoint("t5-base")},
    },
    "translation_en_to_de": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": _same_checkpoint("t5-base")},
    },
    "translation_en_to_ro": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": _same_checkpoint("t5-base")},
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": _same_checkpoint("gpt2")},
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        # Model, config and tokenizer defaults are spelled out separately, with different checkpoints per framework
        "default": {
            "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
        },
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": _same_checkpoint("microsoft/DialoGPT-medium")},
    },
}


2419
2420
2421
2422
2423
def pipeline(
    task: str,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    framework: Optional[str] = None,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a :class:`~transformers.Pipeline`.

    Pipelines are made of:

        - A :doc:`tokenizer <tokenizer>` in charge of mapping raw textual input to tokens.
        - A :doc:`model <model>` to make predictions from the inputs.
        - Some (optional) post processing for enhancing model's output.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
            - :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
            - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
            - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
            - :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
            - :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
            - :obj:`"zero-shot-classification"`: will return a
              :class:`~transformers.ZeroShotClassificationPipeline`.
            - :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)
            or :class:`~transformers.TFPreTrainedModel` (for TensorFlow).

            If not provided, the default for the :obj:`task` will be loaded.
        config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from
            :class:`~transformers.PretrainedConfig`.

            If not provided, the default for the :obj:`task` will be loaded.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained tokenizer inheriting from
            :class:`~transformers.PreTrainedTokenizer`. A tuple ``(identifier, kwargs_dict)`` is also accepted, in
            which case ``kwargs_dict`` is forwarded to :meth:`~transformers.AutoTokenizer.from_pretrained`.

            If not provided, it is inferred from :obj:`model` (or :obj:`config`) when that is given as a string.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        kwargs:
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        :class:`~transformers.Pipeline`: A suitable pipeline for the task.

    Raises:
        KeyError: If :obj:`task` is not a supported task, or :obj:`framework` is neither :obj:`"pt"` nor :obj:`"tf"`.
        ValueError: If the model must be loaded from an identifier but no model class is available for the selected
            framework, or if no tokenizer was given and none can be inferred from :obj:`model` / :obj:`config`.

    Examples::

        from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

        # Sentiment analysis pipeline
        pipeline('sentiment-analysis')

        # Question answering pipeline, specifying the checkpoint identifier
        pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')

        # Named entity recognition pipeline, passing in a specific model and tokenizer
        model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        pipeline('ner', model=model, tokenizer=tokenizer)
    """
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    framework = framework or get_framework(model)
    # Only "pt" and "tf" are model-class keys in the registry; anything else would otherwise
    # surface as an opaque KeyError (or silently pick up "impl"/"default").
    if framework not in ("pt", "tf"):
        raise KeyError("Unknown framework {!r}, expected 'pt' or 'tf'".format(framework))

    targeted_task = SUPPORTED_TASKS[task]
    task_class, model_class = targeted_task["impl"], targeted_task[framework]

    # Use default model for the task if no model is provided
    if model is None:
        model = targeted_task["default"]["model"][framework]

    # The registry stores None for a framework whose library is not installed. Fail here, before any
    # tokenizer/config/modelcard download, instead of hitting `None.from_pretrained` later on.
    if isinstance(model, str) and model_class is None:
        raise ValueError(
            "No model class is available for task {} with framework {!r}. "
            "Make sure the corresponding library is installed.".format(task, framework)
        )

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        elif isinstance(config, str):
            tokenizer = config
        else:
            # Impossible to guess what the right tokenizer is here.
            # ValueError subclasses Exception, so handlers written against the old bare Exception still work.
            raise ValueError(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    # Try to infer modelcard from model or config name (if provided as str)
    modelcard = None
    if isinstance(model, str):
        modelcard = model
    elif isinstance(config, str):
        modelcard = config

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, tuple):
        # For tuple we have (tokenizer name, {kwargs})
        tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
    elif isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
        modelcard = ModelCard.from_pretrained(modelcard)

    # Instantiate model if needed
    if isinstance(model, str):
        # Handle transparent TF/PT model conversion based on the checkpoint file extension
        model_kwargs = {}
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )
        model = model_class.from_pretrained(model, config=config, **model_kwargs)

    return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)