pipelines.py 110 KB
Newer Older
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

Morgan Funtowicz's avatar
Morgan Funtowicz committed
16

17
18
import csv
import json
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import logging
Morgan Funtowicz's avatar
Morgan Funtowicz committed
20
import os
21
import pickle
Aymeric Augustin's avatar
Aymeric Augustin committed
22
import sys
23
import uuid
Morgan Funtowicz's avatar
Morgan Funtowicz committed
24
from abc import ABC, abstractmethod
25
from contextlib import contextmanager
26
from itertools import chain
27
from os.path import abspath, exists
28
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
29
from uuid import UUID
Morgan Funtowicz's avatar
Morgan Funtowicz committed
30
31
32

import numpy as np

33
from .configuration_auto import AutoConfig
34
35
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
Sylvain Gugger's avatar
Sylvain Gugger committed
36
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
37
38
39
40
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
41
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
Morgan Funtowicz's avatar
Morgan Funtowicz committed
42

Aymeric Augustin's avatar
Aymeric Augustin committed
43

Morgan Funtowicz's avatar
Morgan Funtowicz committed
44
if is_tf_available():
Morgan Funtowicz's avatar
Morgan Funtowicz committed
45
    import tensorflow as tf
46

47
    from .modeling_tf_auto import (
48
49
50
51
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
52
        TFAutoModel,
53
        TFAutoModelForCausalLM,
54
        TFAutoModelForQuestionAnswering,
55
        TFAutoModelForSequenceClassification,
56
        TFAutoModelForTokenClassification,
Julien Chaumond's avatar
Julien Chaumond committed
57
        TFAutoModelWithLMHead,
58
    )
Morgan Funtowicz's avatar
Morgan Funtowicz committed
59
60
61

if is_torch_available():
    import torch
62

63
    from .modeling_auto import (
64
65
66
67
68
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
69
        AutoModel,
70
71
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
72
73
74
75
        AutoModelForQuestionAnswering,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
76
    )
Morgan Funtowicz's avatar
Morgan Funtowicz committed
77

78
79
if TYPE_CHECKING:
    from .modeling_tf_utils import TFPreTrainedModel
80
    from .modeling_utils import PreTrainedModel
81

Morgan Funtowicz's avatar
Morgan Funtowicz committed
82

83
84
# Module-level logger named after this module; e.g. ``Pipeline.save_pretrained`` reports errors through it.
logger = logging.getLogger(__name__)

85

thomwolf's avatar
thomwolf committed
86
def get_framework(model=None):
    """
    Pick the deep learning framework (TensorFlow or PyTorch) a pipeline should run on.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`):
            When both frameworks are installed and an actual model object is given, the framework is inferred from
            the model's class name (classes whose name starts with ``"TF"`` map to TensorFlow). A model name
            (string) or no model at all falls back to PyTorch whenever PyTorch is installed.

    Returns:
        :obj:`str`: :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow.

    Raises:
        :obj:`RuntimeError`: If neither TensorFlow 2.0 nor PyTorch is installed.
    """
    tf_installed = is_tf_available()
    torch_installed = is_torch_available()

    if not (tf_installed or torch_installed):
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )

    has_model_object = model is not None and not isinstance(model, str)
    if tf_installed and torch_installed and has_model_object:
        # Both backends are usable and we were handed a model instance: trust its class name.
        return "tf" if model.__class__.__name__.startswith("TF") else "pt"

    # Only one backend installed, or nothing to infer from: prefer PyTorch when it is available.
    return "pt" if torch_installed else "tf"

110

111
112
class PipelineException(Exception):
    """
    Error raised by a :class:`~transformers.Pipeline` while handling a call.

    Args:
        task (:obj:`str`): The task the failing pipeline was created for.
        model (:obj:`str`): Identifier of the model used by the pipeline.
        reason (:obj:`str`): Human-readable explanation; it becomes the exception message.
    """

    def __init__(self, task: str, model: str, reason: str):
        super().__init__(reason)

        # Keep the pipeline context so that handlers can report what failed.
        self.model = model
        self.task = task


128
129
class ArgumentHandler(ABC):
    """
    Abstract interface for the objects that turn the raw ``*args`` / ``**kwargs`` given to a
    :class:`~transformers.pipelines.Pipeline` into the inputs the pipeline works on.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        # Concrete handlers must override this; the base implementation is never meant to run.
        raise NotImplementedError()
Morgan Funtowicz's avatar
Morgan Funtowicz committed
136
137


138
139
class DefaultArgumentHandler(ArgumentHandler):
    """
    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.

    Inputs may be given either positionally or as keyword arguments, but not both. They are normalized as follows:

    - a single string ``"a"`` becomes ``["a"]``;
    - a single list is returned as is;
    - a single object of any other type is wrapped in a one-element list;
    - several arguments are flattened one level into a single list, with strings kept whole
      (``"a", ["b", "c"]`` becomes ``["a", "b", "c"]``).
    """

    @staticmethod
    def handle_kwargs(kwargs: Dict) -> List:
        """
        Normalize keyword-argument inputs.

        Only the values are used (in insertion order); the keyword names are ignored.
        """
        # ``list(chain(values))`` is the same as ``list(values)``, so a single code path covers every case.
        return DefaultArgumentHandler.handle_args(list(kwargs.values()))

    @staticmethod
    def handle_args(args: Sequence[Any]) -> List[str]:
        """
        Normalize positional inputs into a list.

        Args:
            args: The positional arguments received by the pipeline.

        Returns:
            The list of inputs to process (empty when no argument was given).

        Raises:
            ValueError: If several arguments are given and one of them is neither a string nor an iterable.
        """
        if not args:
            return []

        # Only one argument, let's do case by case
        if len(args) == 1:
            first = args[0]
            if isinstance(first, str):
                return [first]
            if isinstance(first, list):
                return first
            return list(args)

        # Multiple arguments (x1, x2, ...): flatten one level.
        flattened = []
        for arg in args:
            if isinstance(arg, str):
                # Strings are iterable too, but must never be split into single characters.
                flattened.append(arg)
            elif isinstance(arg, Iterable):
                flattened.extend(arg)
            else:
                raise ValueError(
                    "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(arg))
                )
        return flattened

    def __call__(self, *args, **kwargs):
        if kwargs and args:
            raise ValueError("Pipeline cannot handle mixed args and kwargs")

        if kwargs:
            return DefaultArgumentHandler.handle_kwargs(kwargs)
        return DefaultArgumentHandler.handle_args(args)
Morgan Funtowicz's avatar
Morgan Funtowicz committed
187
188


189
class PipelineDataFormat:
    """
    Base class for the data formats pipelines can read their inputs from and write their outputs to.

    Supported data formats currently include:

    - JSON
    - CSV
    - stdin/stdout (pipe)

    :obj:`PipelineDataFormat` also includes some utilities to work with multi-columns like mapping from datasets
    columns to pipelines keyword arguments through the :obj:`dataset_kwarg_1=dataset_column_1` format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`):
            The column(s) to read. Several columns are separated by commas; each entry is either a plain column name
            or a ``kwarg=column`` mapping.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.

    Raises:
        :obj:`OSError`: If :obj:`output_path` already exists and :obj:`overwrite` is :obj:`False`, or if
        :obj:`input_path` does not exist.
    """

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(
        self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite: bool = False,
    ):
        self.output_path = output_path
        self.input_path = input_path

        if column is None:
            names = [""]
        else:
            names = column.split(",")

        self.is_multi_columns = len(names) > 1
        if self.is_multi_columns:
            # "kwarg=column" maps a dataset column onto a pipeline keyword; a bare name maps onto itself.
            names = [tuple(name.split("=")) if "=" in name else (name, name) for name in names]
        self.column = names

        # Refuse to clobber an existing output unless explicitly allowed.
        if output_path is not None and not overwrite and exists(abspath(self.output_path)):
            raise OSError("{} already exists on disk".format(self.output_path))

        if input_path is not None and not exists(abspath(self.input_path)):
            raise OSError("{} doesnt exist on disk".format(self.input_path))

    # NOTE(review): this class does not use ABCMeta, so @abstractmethod is not enforced when instantiating it;
    # subclasses are nevertheless expected to override ``__iter__`` and ``save``.
    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Union[dict, List[dict]]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Pickle the provided data object next to :obj:`output_path`, swapping its extension for ``.pickle``.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.

        Returns:
            :obj:`str`: Path where the data has been saved.
        """
        stem = os.path.splitext(self.output_path)[0]
        binary_path = os.path.extsep.join((stem, "pickle"))

        with open(binary_path, "wb+") as sink:
            pickle.dump(data, sink)
        return binary_path

    @staticmethod
    def from_str(
        format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
    ) -> "PipelineDataFormat":
        """
        Build the :class:`~transformers.pipelines.PipelineDataFormat` subclass matching :obj:`format`.

        Args:
            format (:obj:`str`):
                The format of the desired pipeline. Acceptable values are :obj:`"json"`, :obj:`"csv"` or :obj:`"pipe"`.
            output_path (:obj:`str`, `optional`):
                Where to save the outgoing data.
            input_path (:obj:`str`, `optional`):
                Where to look for the input data.
            column (:obj:`str`, `optional`):
                The column to read.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to overwrite the :obj:`output_path`.

        Returns:
            :class:`~transformers.pipelines.PipelineDataFormat`: The proper data format.

        Raises:
            :obj:`KeyError`: If :obj:`format` is not one of the supported formats.
        """
        if format == "json":
            return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        if format == "csv":
            return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        if format == "pipe":
            return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        raise KeyError("Unknown reader {} (Available reader are json/csv/pipe)".format(format))
293
294
295


class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using CSV data format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        """
        Yield one item per data row of the CSV input (the header row provides the column names).

        Yields:
            The value of the configured column or, in multi-column mode, a dict mapping each keyword name to the
            value of its associated column.
        """
        # The csv module requires files opened with newline=""; otherwise quoted fields containing newlines are
        # not interpreted correctly.
        with open(self.input_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Save the provided rows as CSV, using the keys of the first row as header.

        Args:
            data (:obj:`List[dict]`): The data to store.
        """
        # newline="" stops the platform from translating the "\r\n" row terminator the csv writer emits
        # (which would otherwise produce blank rows on Windows).
        with open(self.output_path, "w", newline="") as f:
            # An empty list still creates (or truncates) the file but writes no header.
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using JSON file format.

    The input file is parsed as a whole when the object is created; iteration then looks up the configured
    column name(s) in each entry.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        # NOTE(review): input_path is opened unconditionally, so an input file is required even when this format is
        # only used for writing.
        with open(input_path, "r") as handle:
            self._entries = json.load(handle)

    def __iter__(self):
        if self.is_multi_columns:
            # Map each pipeline keyword onto the value stored under its dataset column.
            for record in self._entries:
                yield {kwarg: record[name] for kwarg, name in self.column}
        else:
            for record in self._entries:
                yield record[self.column[0]]

    def save(self, data: dict):
        """
        Write the provided data object to :obj:`output_path` as JSON, replacing any existing content.

        Args:
            data (:obj:`dict`): The data to store.
        """
        with open(self.output_path, "w") as handle:
            json.dump(data, handle)


Morgan Funtowicz's avatar
Morgan Funtowicz committed
374
375
376
377
378
379
class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the python process.

    Each stdin line is one input. Lines without a tab character are yielded unchanged. Tab-separated lines are split
    on tabs. If several columns were configured, the fields become a dictionary ``{kwarg_x: field_x}`` in column
    order; otherwise the fields are yielded as a tuple.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __iter__(self):
        for line in sys.stdin:
            if "\t" not in line:
                # No dictionary to map arguments
                yield line
                continue

            fields = line.split("\t")
            # ``self.column`` is never empty (it defaults to [""]), and only holds (kwarg, column) pairs in
            # multi-column mode, so that flag is what decides whether fields can be mapped to keywords.
            if self.is_multi_columns:
                yield {kwarg: field for (kwarg, _), field in zip(self.column, fields)}
            else:
                yield tuple(fields)

    def save(self, data: dict):
        """
        Print the data to stdout.

        Args:
            data (:obj:`dict`): The data to store.
        """
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Pickle the data next to :obj:`output_path`, which is mandatory for this format.

        Raises:
            :obj:`KeyError`: If no :obj:`output_path` was provided.
        """
        if self.output_path is None:
            raise KeyError(
                "When using piped input on pipeline outputting large object requires an output file path. "
                "Please provide such output path through --output argument."
            )

        return super().save_binary(data)

Morgan Funtowicz's avatar
Morgan Funtowicz committed
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437

class _ScikitCompat(ABC):
    """
    Interface layer for the Scikit and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()


Sylvain Gugger's avatar
Sylvain Gugger committed
438
PIPELINE_INIT_ARGS = r"""
    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, while a non-negative value will
            run the model on the associated CUDA device id.
        binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Flag indicating if the output of the pipeline should be in a binary format (i.e., pickle) or as raw text.
"""
Morgan Funtowicz's avatar
Morgan Funtowicz committed
466

Lysandre Debut's avatar
Lysandre Debut committed
467

Sylvain Gugger's avatar
Sylvain Gugger committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations.
    Pipeline workflow is defined as a sequence of the following operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipelines, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'`),
    output large tensor objects as nested lists. In order to avoid dumping such large structures as textual data we
    provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
    pickle format.
    """

    default_input_names = None

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: Optional[ArgumentHandler] = None,
        device: int = -1,
        binary_output: bool = False,
    ):

        if framework is None:
            framework = get_framework(model)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.modelcard = modelcard
        self.framework = framework
        # TensorFlow keeps the raw integer ordinal; PyTorch gets a torch.device (negative ordinal -> CPU).
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
        self.binary_output = binary_output
        self._args_parser = args_parser or DefaultArgumentHandler()

        # Special handling: move PyTorch weights to the requested GPU up front.
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer (and model card, if any).

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to save. It will be created if it doesn't exist. If the path points to
                an existing file, an error is logged and nothing is saved.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
        """
        if self.framework == "tf":
            # Any negative ordinal means CPU, consistently with the torch.device built in __init__.
            with tf.device("/CPU:0" if self.device < 0 else "/device:GPU:{}".format(self.device)):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check that the model class is supported by the pipeline.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.

        Raises:
            :class:`~transformers.pipelines.PipelineException`: If the model class name is not supported.
        """
        if not isinstance(supported_models, list):  # Create from a model mapping
            supported_models = [model_class.__name__ for model_class in supported_models.values()]
        if self.model.__class__.__name__ not in supported_models:
            raise PipelineException(
                self.task,
                self.model.base_model_prefix,
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. "
                f"Supported models are {supported_models}",
            )

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse the call arguments with the pipeline's argument handler, then tokenize them into framework tensors.
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs, add_special_tokens=add_special_tokens, return_tensors=self.framework, padding=padding,
        )

        return inputs

    def __call__(self, *args, **kwargs):
        inputs = self._parse_and_tokenize(*args, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework specific forward dispatching.

        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array.

        Returns:
            The first element of the model outputs, as a numpy array, or as a framework tensor if
            :obj:`return_tensors` is :obj:`True` (PyTorch outputs are moved to CPU first).
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs.data, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        return predictions.numpy()
645
646


Sylvain Gugger's avatar
Sylvain Gugger committed
647
# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    This feature extraction pipeline can currently be loaded from :func:`~transformers.pipeline` using the task
    identifier: :obj:`"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models on
    `huggingface.co/models <https://huggingface.co/models>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will use the CPU, a positive value will run the
            model on the associated CUDA device id.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        task: str = "",
    ):
        # Feature extraction always produces binary output, so the flag is fixed rather than exposed.
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of :obj:`float`: The features computed by the model.
        """
        features = super().__call__(*args, **kwargs)
        # The base pipeline returns an array; convert it to plain Python lists for callers.
        return features.tolist()
Morgan Funtowicz's avatar
Morgan Funtowicz committed
717
718


Sylvain Gugger's avatar
Sylvain Gugger committed
719
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any :obj:`ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    This language generation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2).
    See the list of available community models on
    `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
    """

    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
    (except for Alexei and Maria) are discovered.
    The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
    remainder of the story. 1883 Western Siberia,
    a young Grigori Rasputin is asked by his father and a group of men to perform magic.
    Rasputin has a vision and denounces one of the men as a horse thief. Although his
    father initially slaps him for making such an accusation, Rasputin watches as the
    man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
    with people, even a bishop, begging for his blessing. """

    ALLOWED_MODELS = [
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "ReformerModelWithLMHead",
        "GPT2LMHeadModel",
        "OpenAIGPTLMHeadModel",
        "CTRLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
        "TFGPT2LMHeadModel",
        "TFOpenAIGPTLMHeadModel",
        "TFCTRLLMHeadModel",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(self.ALLOWED_MODELS)

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            **tokenizer_kwargs,
        )

        return inputs

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`List[int]`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
        """
        text_inputs = self._args_parser(*args)

        results = []
        for prompt_text in text_inputs:
            # Per-prompt copy: the length adjustments below must not accumulate from one prompt to the next.
            prompt_generate_kwargs = dict(generate_kwargs)

            # Manage correct placement of the tensors
            with self.device_placement():
                if self.model.__class__.__name__ in [
                    "XLNetLMHeadModel",
                    "TransfoXLLMHeadModel",
                    "TFXLNetLMHeadModel",
                    "TFTransfoXLLMHeadModel",
                ]:
                    # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                    padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
                    padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
                    # This impacts max_length and min_length argument that need adjusting.
                    padding_length = padding["input_ids"].shape[-1]
                    for length_key in ("max_length", "min_length"):
                        if prompt_generate_kwargs.get(length_key) is not None:
                            prompt_generate_kwargs[length_key] += padding_length

                    inputs = self._parse_and_tokenize(
                        padding_text + prompt_text, padding=False, add_special_tokens=False
                    )
                else:
                    inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)

                # set input_ids to None to allow empty prompt
                if inputs["input_ids"].shape[-1] == 0:
                    inputs["input_ids"] = None
                    inputs["attention_mask"] = None

                if self.framework == "pt" and inputs["input_ids"] is not None:
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids is None or input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."

                output_sequences = self.model.generate(input_ids=input_ids, **prompt_generate_kwargs)  # BS x SL

            result = []
            for generated_sequence in output_sequences:
                if self.framework == "pt" and generated_sequence is not None:
                    generated_sequence = generated_sequence.cpu()
                generated_sequence = generated_sequence.numpy().tolist()
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )

                    # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    record["generated_text"] = prompt_text + text[prompt_length:]

                result.append(record)
            results += [result]

        if len(results) == 1:
            return results[0]

        return results


Sylvain Gugger's avatar
Sylvain Gugger committed
901
902
903
904
905
906
907
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        return_all_scores (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to return all prediction scores or just the one of the predicted class.
    """,
)
class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using any :obj:`ModelForSequenceClassification`. See the
    `sequence classification examples <../task_summary.html#sequence-classification>`__ for more information.

    This text classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"sentiment-analysis"` (for classifying sequences according to positive or negative
    sentiments).

    The models that this pipeline can use are models that have been fine-tuned on a sequence classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=text-classification>`__.
    """

    def __init__(self, return_all_scores: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
        )

        self.return_all_scores = return_all_scores

    def __call__(self, *args, **kwargs):
        """
        Classify the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) to classify.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
            following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.

            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        outputs = super().__call__(*args, **kwargs)
        # Softmax over the last (label) axis. Subtracting the row maximum first leaves the result unchanged
        # mathematically but keeps np.exp from overflowing to inf (and producing nan) on large logits.
        shifted = outputs - outputs.max(-1, keepdims=True)
        exp_shifted = np.exp(shifted)
        scores = exp_shifted / exp_shifted.sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(item)]
                for item in scores
            ]
        else:
            return [
                {"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
            ]
Morgan Funtowicz's avatar
Morgan Funtowicz committed
961
962


963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot for text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        """
        Normalize the candidate labels.

        A string is interpreted as comma-separated labels: each entry is stripped of surrounding whitespace and
        empty entries are dropped. Any non-string value is returned unchanged.
        """
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",") if label.strip()]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        """
        Build one ``[sequence, hypothesis]`` pair per (sequence, label) combination.

        Labels are parsed before validation so that a string such as ``", "`` (no real label) is rejected instead of
        silently yielding empty-string labels.

        Raises:
            ValueError: if there is no label or no sequence, or if ``hypothesis_template`` has no placeholder the
                label can be formatted into.
        """
        labels = self._parse_labels(labels)
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]

        # Sequence-major ordering: all labels for the first sequence, then all labels for the next, etc.
        return [[sequence, hypothesis_template.format(label)] for sequence in sequences for label in labels]


Sylvain Gugger's avatar
Sylvain Gugger committed
996
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a :obj:`ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the
    candidate label being valid. Any NLI model can be used as long as the first output logit corresponds to
    `contradiction` and the last to `entailment`.

    This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?search=nli>`__.
    """

    # NOTE(review): the default handler instance is shared by every pipeline created without an explicit
    # args_parser; the handler code in this file keeps no per-call state, so sharing appears safe.
    def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
        super().__init__(*args, args_parser=args_parser, **kwargs)

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation="only_first",
        )

        return inputs

    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
        """
        Classify the sequence(s) given as inputs.

        Args:
            sequences (:obj:`str` or :obj:`List[str]`):
                The sequence(s) to classify, will be truncated if the model input is too large.
            candidate_labels (:obj:`str` or :obj:`List[str]`):
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`):
                The template used to turn each label into an NLI-style hypothesis. This template must include a {}
                or similar syntax for the candidate label to be inserted into the template. For example, the default
                template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed
                into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The
                default template works well in many cases, but it may be worthwhile to experiment with different
                templates depending on the task setting.
            multi_class (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized
                such that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are
                considered independent and probabilities are normalized for each candidate by doing a softmax of
                the entailment score vs. the contradiction score.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **sequence** (:obj:`str`) -- The sequence for which this is the output.
            - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
            - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
        """
        outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
        num_sequences = 1 if isinstance(sequences, str) else len(sequences)
        candidate_labels = self._args_parser._parse_labels(candidate_labels)
        # The argument handler emits pairs sequence-major, so rows regroup as (sequence, label, logits).
        reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

        if len(candidate_labels) == 1:
            multi_class = True

        # Both branches use a max-shifted softmax: identical values, but np.exp cannot overflow on large logits.
        if not multi_class:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., -1]
            entail_logits = entail_logits - entail_logits.max(-1, keepdims=True)
            exp_logits = np.exp(entail_logits)
            scores = exp_logits / exp_logits.sum(-1, keepdims=True)
        else:
            # softmax over the entailment vs. contradiction dim for each label independently
            entail_contr_logits = reshaped_outputs[..., [0, -1]]
            entail_contr_logits = entail_contr_logits - entail_contr_logits.max(-1, keepdims=True)
            exp_logits = np.exp(entail_contr_logits)
            scores = exp_logits / exp_logits.sum(-1, keepdims=True)
            scores = scores[..., 1]

        result = []
        for iseq in range(num_sequences):
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences if isinstance(sequences, str) else sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq][top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result


Sylvain Gugger's avatar
Sylvain Gugger committed
1098
1099
1100
1101
1102
1103
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        topk (:obj:`int`, defaults to 5): The number of predictions to return.
    """,
)
class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using any :obj:`ModelWithLMHead`. See the
    `masked language modeling examples <../task_summary.html#masked-language-modeling>`__ for more information.

    This mask filling pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"fill-mask"`.

    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
    which includes the bi-directional models in the library.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=lm-head>`__.

    .. note::

        This pipeline only works for inputs with exactly one token masked.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        topk=5,
        task: str = "",
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)

        self.topk = topk

    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
        """
        Raise a :class:`PipelineException` unless ``masked_index`` holds exactly one element.

        Args:
            masked_index (:obj:`np.ndarray`): Positions of the mask token found in one input sequence.
        """
        numel = np.prod(masked_index.shape)
        if numel > 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
            )
        elif numel < 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
            )

    def __call__(self, *args, targets=None, **kwargs):
        """
        Fill the masked token in the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of prompts) with masked tokens.
            targets (:obj:`str` or :obj:`List[str]`, `optional`):
                When passed, the model will return the scores for the passed token or tokens rather than the top k
                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
                be tokenized and the first resulting token will be used (with a warning).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
            following keys:

            - **sequence** (:obj:`str`) -- The corresponding input with the mask token prediction.
            - **score** (:obj:`float`) -- The corresponding probability.
            - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
            - **token_str** (:obj:`str`) -- The predicted token (to replace the masked one).

        Raises:
            ValueError: if ``targets`` is passed but empty, or if a target tokenizes to no token at all.
        """
        # Resolve targets first so invalid targets fail before the (expensive) forward pass.
        if targets is not None:
            if len(targets) == 0 or len(targets[0]) == 0:
                raise ValueError("At least one target must be provided when passed.")
            if isinstance(targets, str):
                targets = [targets]

            targets_proc = []
            for target in targets:
                target_enc = self.tokenizer.tokenize(target)
                if len(target_enc) == 0:
                    # e.g. a whitespace-only target: there is no token to score, so fail clearly
                    # instead of hitting an IndexError on target_enc[0].
                    raise ValueError("The specified target `{}` could not be tokenized into any token.".format(target))
                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
                    logger.warning(
                        "The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format(
                            target, target_enc[0]
                        )
                    )
                targets_proc.append(target_enc[0])
            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))

        inputs = self._parse_and_tokenize(*args, **kwargs)
        outputs = self._forward(inputs, return_tensors=True)

        results = []
        batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

        for i in range(batch_size):
            input_ids = inputs["input_ids"][i]
            result = []

            if self.framework == "tf":
                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index)

                logits = outputs[i, masked_index.item(), :]
                probs = tf.nn.softmax(logits)
                if targets is None:
                    topk = tf.math.top_k(probs, k=self.topk)
                    values, predictions = topk.values.numpy(), topk.indices.numpy()
                else:
                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
                    sort_inds = tf.reverse(tf.argsort(values), [0])
                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
                    predictions = target_inds[sort_inds.numpy()]
            else:
                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index.numpy())

                logits = outputs[i, masked_index.item(), :]
                probs = logits.softmax(dim=0)
                if targets is None:
                    values, predictions = probs.topk(self.topk)
                else:
                    values = probs[..., target_inds]
                    sort_inds = list(reversed(values.argsort(dim=-1)))
                    values = values[..., sort_inds]
                    predictions = target_inds[sort_inds]

            for v, p in zip(values.tolist(), predictions.tolist()):
                # Copy: `.numpy()` may share memory with the input tensor, and writing the prediction into it
                # would silently mutate the caller-visible tokenized inputs.
                tokens = input_ids.numpy().copy()
                tokens[masked_index] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                result.append(
                    {
                        "sequence": self.tokenizer.decode(tokens),
                        "score": v,
                        "token": p,
                        "token_str": self.tokenizer.convert_ids_to_tokens(p),
                    }
                )

            # Append
            results += [result]

        if len(results) == 1:
            return results[0]
        return results


Sylvain Gugger's avatar
Sylvain Gugger committed
1266
1267
1268
1269
1270
1271
1272
1273
1274
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
            A list of labels to ignore.
        grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
    """,
)
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the
    `named entity recognition examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=token-classification>`__.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        ignore_labels: Optional[List[str]] = None,
        task: str = "",
        grouped_entities: bool = False,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        # ``None`` means the documented default ``["O"]``. A fresh list per instance avoids the shared
        # mutable-default pitfall (mutating one pipeline's ignore_labels must not affect the others).
        self.ignore_labels = ["O"] if ignore_labels is None else ignore_labels
        self.grouped_entities = grouped_entities

    def __call__(self, *args, **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with
            :obj:`grouped_entities=True`) with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word.
            - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
              corresponding token in the sentence.
        """
        inputs = self._args_parser(*args, **kwargs)
        answers = []
        for sentence in inputs:

            # Manage correct placement of the tensors
            with self.device_placement():

                tokens = self.tokenizer(
                    sentence, return_attention_mask=False, return_tensors=self.framework, truncation=True,
                )

                # Forward: keep the logits of the first (only) sequence in the batch
                if self.framework == "tf":
                    logits = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        logits = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            # Softmax over the label axis. Subtracting the per-token max first leaves the result unchanged
            # but stops np.exp from overflowing to inf (which would turn the scores into nan) on large logits.
            exp_logits = np.exp(logits - logits.max(axis=-1, keepdims=True))
            score = exp_logits / exp_logits.sum(axis=-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)

            id2label = self.model.config.id2label
            entities = []
            for idx, label_idx in enumerate(labels_idx):
                label = id2label[label_idx]
                # Filter to labels not in `self.ignore_labels`
                if label in self.ignore_labels:
                    continue
                entities.append(
                    {
                        "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
                        "score": score[idx][label_idx].item(),
                        "entity": label,
                        "index": idx,
                    }
                )

            # Append grouped or ungrouped entities
            answers.append(self.group_entities(entities) if self.grouped_entities else entities)

        if len(answers) == 1:
            return answers[0]
        return answers

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The (non-empty) entities predicted by the pipeline that form one group.

        Returns:
            :obj:`dict` with keys ``entity_group`` (label of the first token), ``score`` (mean token score) and
            ``word`` (the tokens joined back into a string by the tokenizer).
        """
        # The group takes its label from its first token
        entity_label = entities[0]["entity"]
        mean_score = np.mean([sub_entity["score"] for sub_entity in entities])
        tokens = [sub_entity["word"] for sub_entity in entities]

        return {
            "entity_group": entity_label,
            "score": mean_score,
            "word": self.tokenizer.convert_tokens_to_string(tokens),
        }

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Two consecutive entities belong to the same group when their labels share the same type after the last
        ``-`` (so ``B-PER`` and ``I-PER`` match) and their token indices are contiguous.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.

        Returns:
            :obj:`List[dict]`: One dictionary per group, as built by :meth:`group_sub_entities`. Empty when
            :obj:`entities` is empty.
        """
        entity_groups = []
        current_group = []

        for entity in entities:
            if current_group:
                previous = current_group[-1]
                # The split is meant to account for the "B" and "I" prefixes
                same_type = entity["entity"].split("-")[-1] == previous["entity"].split("-")[-1]
                adjacent = entity["index"] == previous["index"] + 1
                if same_type and adjacent:
                    current_group.append(entity)
                    continue
                # Different or non-adjacent entity: close the group built so far
                entity_groups.append(self.group_sub_entities(current_group))
            current_group = [entity]

        # Close the trailing group, if any
        if current_group:
            entity_groups.append(self.group_sub_entities(current_group))

        return entity_groups


NerPipeline = TokenClassificationPipeline
1461
1462


1463
1464
1465
class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal :class:`~transformers.SquadExample`.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a :class:`~transformers.SquadExample`
    from the command-line supplied arguments.
    """

    def __call__(self, *args, **kwargs):
        """
        Normalize the supported call styles into a list of :class:`~transformers.SquadExample`.

        Accepted forms: positional argument(s), ``X=``/``data=`` (a dict with ``question``/``context`` keys, a
        ``SquadExample``, or a list of those), or ``question=``/``context=`` (strings or lists of strings).

        Raises:
            KeyError: If a dict input lacks the ``question`` or ``context`` key.
            ValueError: If an item has an unsupported type or no recognized argument was given.
        """
        # Positional args: handling is the same as for X and data, so forward them to avoid duplicating logic
        if args:
            kwargs["X"] = args[0] if len(args) == 1 else list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if "X" in kwargs or "data" in kwargs:
            inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]

            if isinstance(inputs, (dict, SquadExample)):
                # A single sample. A bare SquadExample is not iterable, so it must be wrapped too.
                inputs = [inputs]
            else:
                # Copy to avoid overriding arguments
                inputs = list(inputs)

            for i, item in enumerate(inputs):
                if isinstance(item, dict):
                    if any(k not in item for k in ["question", "context"]):
                        raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

                    inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif not isinstance(item, SquadExample):
                    raise ValueError(
                        "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
                            "X" if "X" in kwargs else "data"
                        )
                    )

            # Tabular input
        elif "question" in kwargs and "context" in kwargs:
            if isinstance(kwargs["question"], str):
                kwargs["question"] = [kwargs["question"]]

            if isinstance(kwargs["context"], str):
                kwargs["context"] = [kwargs["context"]]

            inputs = [
                QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
            ]
        else:
            raise ValueError("Unknown arguments {}".format(kwargs))

        if not isinstance(inputs, list):
            inputs = [inputs]

        return inputs


Sylvain Gugger's avatar
Sylvain Gugger committed
1525
@add_end_docstrings(PIPELINE_INIT_ARGS)
class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using any :obj:`ModelForQuestionAnswering`. See the
    `question answering examples <../task_summary.html#question-answering>`__ for more information.

    This question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=question-answering>`__.
    """

    default_input_names = "question,context"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=QuestionAnsweringArgumentHandler(),
            device=device,
            task=task,
            **kwargs,
        )

        self.check_model_type(
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
        )

    @staticmethod
    def create_sample(
        question: Union[str, List[str]], context: Union[str, List[str]]
    ) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the :class:`~transformers.SquadExample` internally.
        This helper method encapsulates all the logic for converting question(s) and context(s) to
        :class:`~transformers.SquadExample`.

        We currently support extractive question answering.

        Arguments:
            question (:obj:`str` or :obj:`List[str]`): The question(s) asked.
            context (:obj:`str` or :obj:`List[str]`): The context(s) in which we will look for the answer.

        Returns:
            One or a list of :class:`~transformers.SquadExample`: The corresponding
            :class:`~transformers.SquadExample` grouping question and context.
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            args (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`):
                One or several :class:`~transformers.SquadExample` containing the question and context.
            X (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            data (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context
                (will be treated the same way as if passed as the first positional argument).
            question (:obj:`str` or :obj:`List[str]`):
                One or several question(s) (must be used in conjunction with the :obj:`context` argument).
            context (:obj:`str` or :obj:`List[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                :obj:`question` argument).
            topk (:obj:`int`, `optional`, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood).
            doc_stride (:obj:`int`, `optional`, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split in several chunks
                with some overlap. This argument controls the size of that overlap.
            max_answer_len (:obj:`int`, `optional`, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (:obj:`int`, `optional`, defaults to 384):
                The maximum length of the total sentence (context + question) after tokenization. The context will be
                split in several chunks (using :obj:`doc_stride`) if needed.
            max_question_len (:obj:`int`, `optional`, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not we accept impossible as an answer.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **score** (:obj:`float`) -- The probability associated to the answer.
            - **start** (:obj:`int`) -- The start index of the answer (in the tokenized version of the input).
            - **end** (:obj:`int`) -- The end index of the answer (in the tokenized version of the input).
            - **answer** (:obj:`str`) -- The answer to the question.
        """
        # Set defaults values
        kwargs.setdefault("topk", 1)
        kwargs.setdefault("doc_stride", 128)
        kwargs.setdefault("max_answer_len", 15)
        kwargs.setdefault("max_seq_len", 384)
        kwargs.setdefault("max_question_len", 64)
        kwargs.setdefault("handle_impossible_answer", False)

        if kwargs["topk"] < 1:
            raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))

        if kwargs["max_answer_len"] < 1:
            raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

        # Convert inputs to features
        examples = self._args_parser(*args, **kwargs)
        features_list = [
            squad_convert_examples_to_features(
                examples=[example],
                tokenizer=self.tokenizer,
                max_seq_length=kwargs["max_seq_len"],
                doc_stride=kwargs["doc_stride"],
                max_query_length=kwargs["max_question_len"],
                padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
                is_training=False,
                tqdm_enabled=False,
            )
            for example in examples
        ]
        all_answers = []
        for features, example in zip(features_list, examples):
            model_input_names = self.tokenizer.model_input_names + ["input_ids"]
            fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

            # Manage tensor allocation on correct device
            with self.device_placement():
                if self.framework == "tf":
                    fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                    start, end = self.model(fw_args)[:2]
                    start, end = start.numpy(), end.numpy()
                else:
                    with torch.no_grad():
                        # Retrieve the score for the context tokens only (removing question tokens)
                        fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
                        start, end = self.model(**fw_args)[:2]
                        start, end = start.cpu().numpy(), end.cpu().numpy()

            min_null_score = 1000000  # large and positive
            answers = []
            for (feature, start_, end_) in zip(features, start, end):
                # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
                undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

                # Generate mask
                undesired_tokens_mask = undesired_tokens == 0.0

                # Make sure non-context indexes in the tensor cannot contribute to the softmax
                start_ = np.where(undesired_tokens_mask, -10000.0, start_)
                end_ = np.where(undesired_tokens_mask, -10000.0, end_)

                # Normalize logits and spans to retrieve the answer. This is a softmax. Shifting by the max
                # first leaves the result unchanged but keeps np.exp from overflowing on large logits.
                start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True))
                start_ = start_ / start_.sum(axis=-1, keepdims=True)
                end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True))
                end_ = end_ / end_.sum(axis=-1, keepdims=True)

                if kwargs["handle_impossible_answer"]:
                    min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

                # Mask CLS
                start_[0] = end_[0] = 0.0

                starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
                char_to_word = np.array(example.char_to_word_offset)
                token_to_orig_map = feature.token_to_orig_map

                # Convert the answer (tokens) back to the original text. Tokens outside the context have no
                # entry in token_to_orig_map. They were masked to (near-)zero probability above, but decode can
                # still return them as filler when topk exceeds the number of valid spans, so skip those spans
                # instead of raising KeyError.
                answers += [
                    {
                        "score": score.item(),
                        "start": np.where(char_to_word == token_to_orig_map[s])[0][0].item(),
                        "end": np.where(char_to_word == token_to_orig_map[e])[0][-1].item(),
                        "answer": " ".join(example.doc_tokens[token_to_orig_map[s] : token_to_orig_map[e] + 1]),
                    }
                    for s, e, score in zip(starts, ends, scores)
                    if s in token_to_orig_map and e in token_to_orig_map
                ]

            if kwargs["handle_impossible_answer"]:
                answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})

            answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
            all_answers += answers

        if len(all_answers) == 1:
            return all_answers[0]
        return all_answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Take the output of any :obj:`ModelForQuestionAnswering` and generate probabilities for each span to be
        the actual answer.

        In addition, it filters out some unwanted/impossible cases like answer len being greater than
        max_answer_len or answer end position being before the starting position.
        The method supports output the k-best answer through the topk argument.

        Args:
            start (:obj:`np.ndarray`): Individual start probabilities for each token.
            end (:obj:`np.ndarray`): Individual end probabilities for each token.
            topk (:obj:`int`): Indicates how many possible answer span(s) to extract from the model output.
            max_answer_len (:obj:`int`): Maximum size of the answer to extract from the model's output.

        Returns:
            A tuple ``(starts, ends, scores)`` of arrays: start token indices, end token indices and the span
            scores, sorted by decreasing score.
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidate with end < start and end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) <= topk:
            # argpartition needs kth < len, so a full sort is used when there are not more than topk scores
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]

    def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to actual word in
        the initial context.

        Args:
            text (:obj:`str`): The actual context to extract the answer from.
            start (:obj:`int`): The answer starting token index.
            end (:obj:`int`): The answer end token index.

        Returns:
            Dictionary like :obj:`{'answer': str, 'start': int, 'end': int}`
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for word in text.split(" "):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {
            "answer": " ".join(words),
            "start": max(0, char_start_idx),
            "end": min(len(text), char_end_idx),
        }
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1812
1813


Sylvain Gugger's avatar
Sylvain Gugger committed
1814
@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Pipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"summarization"`.

    The models that this pipeline can use are models that have been fine-tuned on a summarization task, currently:
    '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.

    Usage::

        # use bart in pytorch
        summarizer = pipeline("summarization")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

        # use t5 in tf
        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
    """

    def __init__(self, *args, **kwargs):
        kwargs.update(task="summarization")
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                One or several articles (or one list of articles) to summarize.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
              input.
            - **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the summary.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        assert len(documents) > 0, "Please provide a document to summarize"

        if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
            raise NotImplementedError(
                "Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
            )

        # Some models (e.g. T5) expect a task prefix in front of every input
        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(documents[0], list):
            # Batched input: sequences of different lengths need to be padded
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"

            documents = ([prefix + document for document in documents[0]],)
            padding = True

        elif isinstance(documents[0], str):
            documents = (prefix + documents[0],)
            padding = False
        else:
            raise ValueError(
                "`documents[0]`: {} has the wrong format. It should be either of type `str` or type `list`".format(
                    documents[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*documents, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]
            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            min_length = generate_kwargs.get("min_length", self.model.config.min_length)
            if input_length < min_length // 2:
                logger.warning(
                    "Your min_length is set to {}, but your input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
                        min_length, input_length
                    )
                )

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length < max_length:
                logger.warning(
                    "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
                        max_length, input_length
                    )
                )

            summaries = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs
            )

            results = []
            for summary in summaries:
                record = {}
                if return_tensors:
                    record["summary_token_ids"] = summary
                if return_text:
                    record["summary_text"] = self.tokenizer.decode(
                        summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


Sylvain Gugger's avatar
Sylvain Gugger committed
1944
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TranslationPipeline(Pipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=translation>`__.

    Usage::

        en_fr_translator = pipeline("translation_en_to_fr")
        en_fr_translator("How old are you?")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Texts to be translated.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the
            following keys:

            - **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

        # Some models (e.g. T5) expect a task prefix in front of every input
        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(args[0], list):
            # Batched input: sequences of different lengths need to be padded
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            args = ([prefix + text for text in args[0]],)
            padding = True

        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                "`args[0]`: {} has the wrong format. It should be either of type `str` or type `list`".format(args[0])
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length > 0.9 * max_length:
                logger.warning(
                    "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
                        input_length, max_length
                    )
                )

            translations = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs
            )
            results = []
            for translation in translations:
                record = {}
                if return_tensors:
                    record["translation_token_ids"] = translation
                if return_text:
                    record["translation_text"] = self.tokenizer.decode(
                        translation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


2052
2053
2054
class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    :class:`~transformers.ConversationalPipeline`. The conversation contains a number of utility function to manage the
    addition of new user input and generated model responses. A conversation needs to contain an unprocessed user input
    before being passed to the :class:`~transformers.ConversationalPipeline`. This user input is either created when
    the class is instantiated, or by calling :obj:`conversation.add_user_input("input")` after a conversation
    turn.

    Arguments:
        text (:obj:`str`, `optional`):
            The initial user input to start the conversation. If not provided, a user input needs to be provided
            manually using the :meth:`~transformers.Conversation.add_user_input` method before the conversation can
            begin.
        conversation_id (:obj:`uuid.UUID`, `optional`):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big lebowski.")

        conversation.add_user_input("Is it good?")
    """

    def __init__(self, text: Optional[str] = None, conversation_id: Optional[UUID] = None):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        self.uuid: UUID = conversation_id
        self.past_user_inputs: List[str] = []
        self.generated_responses: List[str] = []
        self.history: List[int] = []
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This populates the internal :obj:`new_user_input`
        field.

        A pending input that is an empty string is silently replaced; any other pending input is only replaced when
        :obj:`overwrite` is :obj:`True` (a warning is logged in both of those cases).

        Args:
            text (:obj:`str`): The user input for the next conversation round.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not existing and unprocessed user input should be overwritten when this function is called.
        """
        if self.new_user_input:
            if overwrite:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
                        self.new_user_input, text
                    )
                )
                self.new_user_input = text
            else:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
                    "Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
                )
        else:
            self.new_user_input = text

    def mark_processed(self):
        """
        Mark the conversation as processed (moves the content of :obj:`new_user_input` to :obj:`past_user_inputs`) and
        empties the :obj:`new_user_input` field.

        An empty-string input is moved as well, so that :obj:`past_user_inputs` stays aligned turn-by-turn with
        :obj:`generated_responses`.
        """
        if self.new_user_input is not None:
            self.past_user_inputs.append(self.new_user_input)
        self.new_user_input = None

    def append_response(self, response: str):
        """
        Append a response to the list of generated responses.

        Args:
            response (:obj:`str`): The model generated response.
        """
        self.generated_responses.append(response)

    def set_history(self, history: List[int]):
        """
        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
        The history is used by the model to generate responses based on the previous conversation turns.

        Args:
            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
        """
        self.history = history

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Return:
            :obj:`str`:

            Example:
            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
            user >> Going to the movies tonight - any suggestions?
            bot >> The Big Lebowski
        """
        output = "Conversation id: {} \n".format(self.uuid)
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            output += "user >> {} \n".format(user_input)
            output += "bot >> {} \n".format(generated_response)
        if self.new_user_input is not None:
            output += "user >> {} \n".format(self.new_user_input)
        return output


Sylvain Gugger's avatar
Sylvain Gugger committed
2167
2168
2169
2170
2171
2172
2173
@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        min_length_for_response (:obj:`int`, `optional`, defaults to 32):
            The minimum length (in number of tokens) for a response.
    """,
)
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    This conversational pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: `'microsoft/DialoGPT-small'`, `'microsoft/DialoGPT-medium'`, `'microsoft/DialoGPT-large'`.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
        # Fall back on the EOS token for padding when the tokenizer has no dedicated padding token
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.eos_token_id
        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
            updated generated responses for those containing a new user input.

        Raises:
            ValueError: If the input is not a conversation or list of conversations, or if a conversation has no new
            user input to process.
        """

        # Input validation: a single conversation is wrapped so that it goes through the same checks as a batch
        if isinstance(conversations, Conversation):
            conversations = [conversations]
        elif not isinstance(conversations, list):
            raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

        for conversation in conversations:
            assert isinstance(
                conversation, Conversation
            ), "DialoguePipeline expects a Conversation or list of Conversations as an input"
            if conversation.new_user_input is None:
                raise ValueError(
                    "Conversation with UUID {} does not contain new user input to process. "
                    "Add user inputs with the conversation's `add_user_input` method".format(conversation.uuid)
                )
        assert (
            self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
        ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"

        with self.device_placement():

            inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
            histories = [conversation.history for conversation in conversations]
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            inputs = self._concat_inputs_history(inputs, histories, max_length)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            if input_length > 0.9 * max_length:
                logger.warning(
                    "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                    "You might consider trimming the early phase of the conversation".format(input_length, max_length)
                )
            generated_responses = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs
            )

            cleaned_history = self._clean_padding_history(generated_responses)
            output = []
            for conversation_index, conversation in enumerate(conversations):
                conversation.mark_processed()
                # Slice the raw generated sequence: every input was padded to exactly ``input_length`` tokens, whereas
                # the cleaned history may have had consecutive padding collapsed, which would shift that offset.
                conversation.generated_responses.append(
                    self.tokenizer.decode(
                        generated_responses[conversation_index][input_length:],
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                )
                conversation.set_history(cleaned_history[conversation_index])
                output.append(conversation)
            if len(output) == 1:
                return output[0]
            else:
                return output

    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        input_ids = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for ids in input_ids:
            ids.append(self.tokenizer.eos_token_id)
        return input_ids

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
        an input:
            - at the end of the concatenated history and new user input, so that all input to the model have the same
                length
            - at the end of the generated response, as some responses will be longer than others
        This method cleans up these padding token so that the history for each conversation is not impacted by the
        batching process. Runs of consecutive padding tokens are collapsed into a single one.
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in sequence:
                if token == self.pad_token_id:
                    if is_previous_pad:
                        continue
                    else:
                        is_previous_pad = True
                else:
                    is_previous_pad = False
                if self.framework == "pt":
                    sequence_tokens.append(token.item())
                else:
                    sequence_tokens.append(int(token.numpy()))

            outputs.append(sequence_tokens)
        return outputs

    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context.

        When a conversation is longer than ``max_length - self.min_length_for_response`` tokens, its oldest turns
        (everything up to and including the first EOS token) are dropped one at a time until it fits. The latest
        turn is always kept, even if it alone is too long. Sequences are then right-padded with
        :obj:`self.pad_token_id` to the longest one; the attention mask is 1 on real tokens and 0 on padding.
        """
        max_input_length = max_length - self.min_length_for_response
        outputs = []
        for new_input, history in zip(inputs, histories):
            if history is not None:
                new_input = history + new_input
            while len(new_input) > max_input_length:
                try:
                    eos_index = new_input.index(self.tokenizer.eos_token_id)
                except ValueError:
                    # No turn boundary left to cut on
                    break
                if eos_index == len(new_input) - 1:
                    # Only the latest turn is left
                    break
                new_input = new_input[eos_index + 1 :]
            outputs.append(new_input)
        max_len = max(len(item) for item in outputs)
        input_ids = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
        attention_mask = [[1] * len(output) + [0] * (max_len - len(output)) for output in outputs]
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}, tensor_type=self.framework)


2363
# Register all the supported tasks here
Morgan Funtowicz's avatar
Morgan Funtowicz committed
2364
SUPPORTED_TASKS = {
    # Each task maps to:
    #   "impl":    the Pipeline subclass implementing the task
    #   "tf"/"pt": the auto-model class for that framework, or None when the framework is not installed
    #   "default": per-framework checkpoint identifiers ("model", and for some tasks "config"/"tokenizer");
    #              presumably used by `pipeline()` when the caller does not pass one - confirm against `pipeline()`
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": TFAutoModel if is_tf_available() else None,
        "pt": AutoModel if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-cased",
                "tf": "distilbert-base-cased",
            },
        },
    },
    "sentiment-analysis": {
        "impl": TextClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
            },
        },
    },
    "ner": {
        "impl": TokenClassificationPipeline,
        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
        "pt": AutoModelForTokenClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
            },
        },
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-cased-distilled-squad",
                "tf": "distilbert-base-cased-distilled-squad",
            },
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForMaskedLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilroberta-base",
                "tf": "distilroberta-base",
            },
        },
    },
    # The default checkpoint differs between frameworks for summarization
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "sshleifer/distilbart-cnn-12-6",
                "tf": "t5-small",
            },
        },
    },
    # All translation directions share the same pipeline and default checkpoint
    "translation_en_to_fr": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "t5-base",
                "tf": "t5-base",
            },
        },
    },
    "translation_en_to_de": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "t5-base",
                "tf": "t5-base",
            },
        },
    },
    "translation_en_to_ro": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "t5-base",
                "tf": "t5-base",
            },
        },
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "gpt2",
                "tf": "gpt2",
            },
        },
    },
    # The only task whose defaults also pin a config and a tokenizer
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "facebook/bart-large-mnli",
                "tf": "roberta-large-mnli",
            },
            "config": {
                "pt": "facebook/bart-large-mnli",
                "tf": "roberta-large-mnli",
            },
            "tokenizer": {
                "pt": "facebook/bart-large-mnli",
                "tf": "roberta-large-mnli",
            },
        },
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "microsoft/DialoGPT-medium",
                "tf": "microsoft/DialoGPT-medium",
            },
        },
    },
}


2456
2457
2458
2459
2460
def pipeline(
    task: str,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    framework: Optional[str] = None,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a :class:`~transformers.Pipeline`.

    Pipelines are made of:

        - A :doc:`tokenizer <tokenizer>` in charge of mapping raw textual input to tokens.
        - A :doc:`model <model>` to make predictions from the inputs.
        - Some (optional) post processing for enhancing model's output.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
            - :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
            - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
            - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
            - :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
            - :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
            - :obj:`"zero-shot-classification"`: will return a
              :class:`~transformers.ZeroShotClassificationPipeline`.
            - :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.

            The authoritative list is the set of keys of ``SUPPORTED_TASKS``.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)
            or :class:`~transformers.TFPreTrainedModel` (for TensorFlow).

            If not provided, the default model identifier for the :obj:`task` (and selected framework) will be loaded.
        config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from
            :class:`~transformers.PretrainedConfig`. A string is resolved with ``AutoConfig.from_pretrained``.

            If not provided, no explicit configuration is passed when loading the model.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier, an actual pretrained tokenizer inheriting from :class:`~transformers.PreTrainedTokenizer`,
            or a ``(identifier, kwargs_dict)`` tuple whose dict is forwarded to ``AutoTokenizer.from_pretrained``.

            If not provided, it is inferred from :obj:`model` when that is a string identifier (including the
            task default), otherwise from :obj:`config` when that is a string identifier.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no
            model is provided.
        kwargs:
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        :class:`~transformers.Pipeline`: A suitable pipeline for the task.

    Raises:
        KeyError: If :obj:`task` is not one of the supported tasks.
        Exception: If no tokenizer is given and it cannot be inferred from :obj:`model` or :obj:`config`.

    Examples::

        from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

        # Sentiment analysis pipeline
        pipeline('sentiment-analysis')

        # Question answering pipeline, specifying the checkpoint identifier
        pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')

        # Named entity recognition pipeline, passing in a specific model and tokenizer
        model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        pipeline('ner', model=model, tokenizer=tokenizer)
    """
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    framework = framework or get_framework(model)

    targeted_task = SUPPORTED_TASKS[task]
    task_class, model_class = targeted_task["impl"], targeted_task[framework]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        model = targeted_task["default"]["model"][framework]

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        elif isinstance(config, str):
            tokenizer = config
        else:
            # Impossible to guess what the right tokenizer is here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PretrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    # Try to infer modelcard from model or config name (if provided as str)
    modelcard = None
    if isinstance(model, str):
        modelcard = model
    elif isinstance(config, str):
        modelcard = config

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, (str, tuple)):
        if isinstance(tokenizer, tuple):
            # For tuple we have (tokenizer name, {kwargs})
            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
        else:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
        modelcard = ModelCard.from_pretrained(modelcard)

    # Instantiate model if needed
    if isinstance(model, str):
        # Handle transparent TF/PT model conversion: a checkpoint file whose extension
        # belongs to the other framework is loaded with the cross-framework flag set.
        model_kwargs = {}
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )
        model = model_class.from_pretrained(model, config=config, **model_kwargs)

    return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)