"sims/nic/vscode:/vscode.git/clone" did not exist on "32a74b8eb0183e6911e87292795f8a8968d214af"
pipelines.py 118 KB
Newer Older
Morgan Funtowicz's avatar
Morgan Funtowicz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import csv
import json
import os
import pickle
import sys
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from uuid import UUID

import numpy as np

from .configuration_auto import AutoConfig
from .configuration_utils import PretrainedConfig
from .data import SquadExample, squad_convert_examples_to_features
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import PaddingStrategy
from .utils import logging


if is_tf_available():
    import tensorflow as tf

    from .modeling_tf_auto import (
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTokenClassification,
    )

if is_torch_available():
    import torch

    from .modeling_auto import (
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
        AutoModelForQuestionAnswering,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
    )

if TYPE_CHECKING:
    from .modeling_tf_utils import TFPreTrainedModel
    from .modeling_utils import PreTrainedModel


logger = logging.get_logger(__name__)


def get_framework(model):
    """
    Select framework (TensorFlow or PyTorch) to use.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
            the model name). If no specific model is provided, defaults to using PyTorch.
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        if is_torch_available() and not is_tf_available():
            model = AutoModel.from_pretrained(model)
        elif is_tf_available() and not is_torch_available():
            model = TFAutoModel.from_pretrained(model)
        else:
            try:
                model = AutoModel.from_pretrained(model)
            except OSError:
                model = TFAutoModel.from_pretrained(model)

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework


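# A minimal usage sketch (model name is illustrative): ``get_framework`` accepts either a
# model identifier or an already-loaded model instance.
#
#     get_framework("bert-base-cased")  # loads the model, then returns "pt" or "tf"
#
#     model = AutoModel.from_pretrained("bert-base-cased")
#     get_framework(model)  # inspects the class name, returns "pt"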
def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
    """
    Select a default model to use for a given task. Defaults to pytorch if ambiguous.

    Args:
        targeted_task (:obj:`Dict` ):
           Dictionary representing the given task, which should contain the default models

        framework (:obj:`str`, None)
           "pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.

        task_options (:obj:`Any`, None)
           Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for
           translation task.

    Returns

        :obj:`str` The model string representing the default model for this pipeline
    """
    if is_torch_available() and not is_tf_available():
        framework = "pt"
    elif is_tf_available() and not is_torch_available():
        framework = "tf"

    defaults = targeted_task["default"]
    if task_options:
        if task_options not in defaults:
            raise ValueError("The task does not provide any default models for options {}".format(task_options))
        default_models = defaults[task_options]["model"]
    elif "model" in defaults:
        default_models = targeted_task["default"]["model"]
    else:
        # XXX This error message needs to be updated to be more generic if more tasks are going to become
        # parametrized
        raise ValueError('The task defaults can\'t be correctly selected. You probably meant "translation_XX_to_YY"')

    if framework is None:
        framework = "pt"

    return default_models[framework]
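

# For illustration, a hypothetical task dictionary has the shape expected above:
#
#     targeted_task = {"default": {"model": {"pt": "model-name-pt", "tf": "model-name-tf"}}}
#     get_default_model(targeted_task, None, None)  # -> "model-name-pt" when PyTorch is selected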


class PipelineException(Exception):
    """
    Raised by a :class:`~transformers.Pipeline` when handling __call__.

    Args:
        task (:obj:`str`): The task of the pipeline.
        model (:obj:`str`): The model used by the pipeline.
        reason (:obj:`str`): The error message to display.
    """

    def __init__(self, task: str, model: str, reason: str):
        super().__init__(reason)

        self.task = task
        self.model = model


class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class PipelineDataFormat:
    """
    Base class for all the pipeline-supported data formats, both for reading and writing. Supported data formats
    currently include:

    - JSON
    - CSV
    - stdin/stdout (pipe)

    :obj:`PipelineDataFormat` also includes some utilities to work with multi-columns like mapping from datasets
    columns to pipelines keyword arguments through the :obj:`dataset_kwarg_1=dataset_column_1` format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite: bool = False,
    ):
        self.output_path = output_path
        self.input_path = input_path
        self.column = column.split(",") if column is not None else [""]
        self.is_multi_columns = len(self.column) > 1

        if self.is_multi_columns:
            self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]

        if output_path is not None and not overwrite:
            if exists(abspath(self.output_path)):
                raise OSError("{} already exists on disk".format(self.output_path))

        if input_path is not None:
            if not exists(abspath(self.input_path)):
                raise OSError("{} doesn't exist on disk".format(self.input_path))

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Union[dict, List[dict]]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Save the provided data object as a pickle-formatted binary data on the disk.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.

        Returns:
            :obj:`str`: Path where the data has been saved.
        """
        path, _ = os.path.splitext(self.output_path)
        binary_path = os.path.extsep.join((path, "pickle"))

        with open(binary_path, "wb+") as f_output:
            pickle.dump(data, f_output)

        return binary_path

    @staticmethod
    def from_str(
        format: str,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ) -> "PipelineDataFormat":
        """
        Creates an instance of the right subclass of :class:`~transformers.pipelines.PipelineDataFormat` depending on
        :obj:`format`.

        Args:
            format: (:obj:`str`):
                The format of the desired pipeline. Acceptable values are :obj:`"json"`, :obj:`"csv"` or :obj:`"pipe"`.
            output_path (:obj:`str`, `optional`):
                Where to save the outgoing data.
            input_path (:obj:`str`, `optional`):
                Where to look for the input data.
            column (:obj:`str`, `optional`):
                The column to read.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to overwrite the :obj:`output_path`.

        Returns:
            :class:`~transformers.pipelines.PipelineDataFormat`: The proper data format.
        """
        if format == "json":
            return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        elif format == "csv":
            return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        elif format == "pipe":
            return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        else:
            raise KeyError("Unknown reader {} (Available readers are json/csv/pipe)".format(format))
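

# A minimal usage sketch (file names are hypothetical): ``from_str`` picks the reader, and
# iterating over it yields the requested column of each record:
#
#     data_format = PipelineDataFormat.from_str(
#         "csv", output_path="predictions.csv", input_path="inputs.csv", column="text"
#     )
#     for text in data_format:
#         ...  # the "text" column of one row at a time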


class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using CSV data format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        with open(self.input_path, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`List[dict]`): The data to store.
        """
        with open(self.output_path, "w") as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using JSON file format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        with open(input_path, "r") as f:
            self._entries = json.load(f)

    def __iter__(self):
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        """
        Save the provided data object in a json file.

        Args:
            data (:obj:`dict`): The data to store.
        """
        with open(self.output_path, "w") as f:
            json.dump(data, f)


class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the python process. For multi-column data, columns should be separated by \t

    If columns are provided, then the output will be a dictionary with {column_x: value_x}

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __iter__(self):
        for line in sys.stdin:
            # Split for multi-columns
            if "\t" in line:
                line = line.split("\t")
                if self.column:
                    # Dictionary to map arguments
                    yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
                else:
                    yield tuple(line)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        """
        Print the data.

        Args:
            data (:obj:`dict`): The data to store.
        """
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        if self.output_path is None:
            raise KeyError(
                "When using piped input, a pipeline outputting large objects requires an output file path. "
                "Please provide such an output path through the --output argument."
            )

        return super().save_binary(data)
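

# A hypothetical shell invocation: tab-separated fields from stdin are mapped onto pipeline
# keyword arguments through the column mapping, e.g.
#
#     cat QC_pairs.tsv | transformers-cli run --task question-answering \
#         --column question=question,context=context --format pipe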


class _ScikitCompat(ABC):
    """
    Interface layer for the Scikit and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()


PIPELINE_INIT_ARGS = r"""
    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, while a positive integer will run
            the model on the associated CUDA device id.
        binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Flag indicating if the output of the pipeline should be in a binary format (i.e., pickle) or as raw text.
"""


@add_end_docstrings(PIPELINE_INIT_ARGS)
class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations. Pipeline workflow is defined as a sequence of the following
    operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipelines, such as :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'`), output
    large tensor objects as nested lists. In order to avoid dumping such large structures as textual data, we provide
    the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the pickle
    format.
    """

    default_input_names = None

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
    ):

        if framework is None:
            framework = get_framework(model)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.modelcard = modelcard
        self.framework = framework
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
        self.binary_output = binary_output

        # Special handling
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer.

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to save. It will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
        """
        if self.framework == "tf":
            with tf.device("/CPU:0" if self.device == -1 else "/device:GPU:{}".format(self.device)):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check if the model class is supported by the pipeline.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.
        """
        if not isinstance(supported_models, list):  # Create from a model mapping
            supported_models = [item[1].__name__ for item in supported_models.items()]
        if self.model.__class__.__name__ not in supported_models:
            raise PipelineException(
                self.task,
                self.model.base_model_prefix,
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
            )

    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
        )

        return inputs

    def __call__(self, *args, **kwargs):
        inputs = self._parse_and_tokenize(*args, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework specific forward dispatching

        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array

        Returns:
            Numpy array
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs.data, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        else:
            return predictions.numpy()


# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    This feature extraction pipeline can currently be loaded from :func:`~transformers.pipeline` using the task
    identifier: :obj:`"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models on
    `huggingface.co/models <https://huggingface.co/models>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, while a positive integer will run
            the model on the associated CUDA device id.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        task: str = "",
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of :obj:`float`: The features computed by the model.
        """
        return super().__call__(*args, **kwargs).tolist()
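

# A minimal usage sketch (model name is illustrative) using the ``pipeline`` factory defined
# later in this module:
#
#     extractor = pipeline("feature-extraction", model="distilbert-base-cased")
#     features = extractor("We are very happy.")  # nested list of floats, one vector per token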


@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any :obj:`ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    This language generation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available
    community models on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """
    In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
    voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
    Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
    and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
    accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
    begging for his blessing. <eod> </s> <eos>
    """
    ALLOWED_MODELS = [
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "ReformerModelWithLMHead",
        "GPT2LMHeadModel",
        "OpenAIGPTLMHeadModel",
        "CTRLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
        "TFGPT2LMHeadModel",
        "TFOpenAIGPTLMHeadModel",
        "TFCTRLLMHeadModel",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(self.ALLOWED_MODELS)

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments

    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            **tokenizer_kwargs,
        )

        return inputs

    def __call__(
        self,
        text_inputs,
        return_tensors=False,
        return_text=True,
        clean_up_tokenization_spaces=False,
        prefix=None,
        **generate_kwargs
    ):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (:obj:`str`, `optional`):
                Prefix added to prompt.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
        """

        results = []
        for prompt_text in text_inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                prefix = prefix if prefix is not None else self.model.config.prefix
                if prefix is None and self.model.__class__.__name__ in [
                    "XLNetLMHeadModel",
                    "TransfoXLLMHeadModel",
                    "TFXLNetLMHeadModel",
                    "TFTransfoXLLMHeadModel",
                ]:
                    # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                    prefix = self.XL_PREFIX

                if prefix:
                    prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False)
                    # This impacts max_length and min_length argument that need adjusting.
                    prefix_length = prefix_inputs["input_ids"].shape[-1]
                    if generate_kwargs.get("max_length", None) is not None:
                        generate_kwargs["max_length"] += prefix_length
                    if generate_kwargs.get("min_length", None) is not None:
                        generate_kwargs["min_length"] += prefix_length

                prefix = prefix or ""
                inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False)

                # set input_ids to None to allow empty prompt
                if inputs["input_ids"].shape[-1] == 0:
                    inputs["input_ids"] = None
                    inputs["attention_mask"] = None

                if self.framework == "pt" and inputs["input_ids"] is not None:
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids is None or input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."

                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL

            result = []
            for generated_sequence in output_sequences:
                if self.framework == "pt" and generated_sequence is not None:
                    generated_sequence = generated_sequence.cpu()
                generated_sequence = generated_sequence.numpy().tolist()
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )

                    # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    record["generated_text"] = prompt_text + text[prompt_length:]

                result.append(record)
            results += [result]

        if len(results) == 1:
            return results[0]

        return results
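

# A minimal usage sketch (model name is illustrative):
#
#     generator = pipeline("text-generation", model="gpt2")
#     generator("Once upon a time,", max_length=30)
#     # [{'generated_text': 'Once upon a time, ...'}]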


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        return_all_scores (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to return all prediction scores or just the one of the predicted class.
    """,
)
class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using any :obj:`ModelForSequenceClassification`. See the `sequence classification
    examples <../task_summary.html#sequence-classification>`__ for more information.

    This text classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"sentiment-analysis"` (for classifying sequences according to positive or negative
    sentiments).

    If multiple classification labels are available (:obj:`model.config.num_labels >= 2`), the pipeline will run a
    softmax over the results. If there is a single label, the pipeline will run a sigmoid over the result.

    The models that this pipeline can use are models that have been fine-tuned on a sequence classification task. See
    the up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=text-classification>`__.
    """

    def __init__(self, return_all_scores: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
        )

        self.return_all_scores = return_all_scores

    def __call__(self, *args, **kwargs):
        """
        Classify the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) to classify.

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.

            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        outputs = super().__call__(*args, **kwargs)

        if self.model.config.num_labels == 1:
            scores = 1.0 / (1.0 + np.exp(-outputs))
        else:
            scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(item)]
                for item in scores
            ]
        else:
            return [
                {"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
            ]
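

# A minimal usage sketch (the default sentiment model is resolved by the ``pipeline`` factory):
#
#     classifier = pipeline("sentiment-analysis")
#     classifier("This movie was great!")
#     # [{'label': 'POSITIVE', 'score': ...}]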


class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot for text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]
        labels = self._parse_labels(labels)

        sequence_pairs = []
        for sequence in sequences:
            sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])

        return sequence_pairs
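

# For example, the handler expands each (sequence, label) combination into an NLI pair:
#
#     handler = ZeroShotClassificationArgumentHandler()
#     handler(["Who are you voting for?"], ["politics", "sports"], "This example is {}.")
#     # [['Who are you voting for?', 'This example is politics.'],
#     #  ['Who are you voting for?', 'This example is sports.']]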


@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a :obj:`ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the candidate
    label being valid. Any NLI model can be used, but the id of the `entailment` label must be included in the model
    config's :attr:`~transformers.PretrainedConfig.label2id`.

    This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task identifier:
    :obj:`"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task. See the up-to-date list
    of available models on `huggingface.co/models <https://huggingface.co/models?search=nli>`__.
    """

    def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._args_parser = args_parser
        if self.entailment_id == -1:
            logger.warning(
                "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
                "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
            )

    @property
    def entailment_id(self):
        for label, ind in self.model.config.label2id.items():
            if label.lower().startswith("entail"):
                return ind
        return -1

    def _parse_and_tokenize(
        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
    ):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
        sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template)
        inputs = self.tokenizer(
            sequence_pairs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation="only_first",
        )

        return inputs

    def __call__(
        self,
        sequences: Union[str, List[str]],
        candidate_labels,
        hypothesis_template="This example is {}.",
        multi_class=False,
    ):
        """
1081
1082
        Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
        documentation for more information.
1083
1084

        Args:
1085
            sequences (:obj:`str` or :obj:`List[str]`):
Sylvain Gugger's avatar
Sylvain Gugger committed
1086
                The sequence(s) to classify, will be truncated if the model input is too large.
1087
            candidate_labels (:obj:`str` or :obj:`List[str]`):
1088
1089
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
1090
            hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`):
Sylvain Gugger's avatar
Sylvain Gugger committed
1091
1092
                The template used to turn each label into an NLI-style hypothesis. This template must include a {} or
                similar syntax for the candidate label to be inserted into the template. For example, the default
Sylvain Gugger's avatar
Sylvain Gugger committed
1093
1094
1095
1096
                template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed
                into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The
                default template works well in many cases, but it may be worthwhile to experiment with different
                templates depending on the task setting.
            multi_class (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized such
                that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are considered
                independent and probabilities are normalized for each candidate by doing a softmax of the entailment
                score vs. the contradiction score.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **sequence** (:obj:`str`) -- The sequence for which this is the output.
            - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
            - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
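
        Example (a minimal sketch; assumes the default zero-shot model can be downloaded)::

            classifier = pipeline("zero-shot-classification")
            classifier(
                "Who are you voting for in 2020?",
                candidate_labels=["politics", "public health", "economics"],
            )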
        """
        if sequences and isinstance(sequences, str):
            sequences = [sequences]

        outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
        num_sequences = len(sequences)
        candidate_labels = self._args_parser._parse_labels(candidate_labels)
        reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

        if len(candidate_labels) == 1:
            multi_class = True

        if not multi_class:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., self.entailment_id]
            scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
        else:
            # softmax over the entailment vs. contradiction dim for each label independently
            entailment_id = self.entailment_id
            contradiction_id = -1 if entailment_id == 0 else 0
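            # Heuristic assumption: the contradiction logit sits at the opposite end of the label axis
            # from entailment (e.g. a label layout of [contradiction, neutral, entailment]).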
            entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
            scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
            scores = scores[..., 1]

        result = []
        for iseq in range(num_sequences):
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences if isinstance(sequences, str) else sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq][top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        top_k (:obj:`int`, defaults to 5): The number of predictions to return.
    """,
)
class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using any :obj:`ModelWithLMHead`. See the `masked language modeling
    examples <../task_summary.html#masked-language-modeling>`__ for more information.

    This mask filling pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"fill-mask"`.

    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
    which includes the bi-directional models in the library. See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=masked-lm>`__.

    .. note::

        This pipeline only works for inputs with exactly one token masked.
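
    Usage (a minimal sketch; model download is required and the exact mask token depends on the tokenizer)::

        fill_masker = pipeline("fill-mask")
        fill_masker(f"HuggingFace is creating a {fill_masker.tokenizer.mask_token} that the community uses.")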
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        top_k=5,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)

        if "topk" in kwargs:
            warnings.warn(
                "The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
                FutureWarning,
            )
            self.top_k = kwargs.pop("topk")
        else:
            self.top_k = top_k

    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
        numel = np.prod(masked_index.shape)
        if numel > 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
            )
        elif numel < 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
            )

    def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
        """
        Fill the masked token in the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) with masked tokens.
            targets (:obj:`str` or :obj:`List[str]`, `optional`):
                When passed, the model will return the scores for the passed token or tokens rather than the top k
                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
                tokenized and the first resulting token will be used (with a warning).
            top_k (:obj:`int`, `optional`):
                When passed, overrides the number of predictions to return.

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries with the following keys:

            - **sequence** (:obj:`str`) -- The corresponding input with the mask token prediction.
            - **score** (:obj:`float`) -- The corresponding probability.
            - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
            - **token_str** (:obj:`str`) -- The predicted token (to replace the masked one).
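
        Example (a hedged sketch; substitute the model's actual mask token for ``[MASK]``)::

            fill_masker("Paris is the capital of [MASK].", targets=["France"], top_k=1)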
        """
        inputs = self._parse_and_tokenize(*args, **kwargs)
        outputs = self._forward(inputs, return_tensors=True)

        results = []
        batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

        if targets is not None:
            if len(targets) == 0 or len(targets[0]) == 0:
                raise ValueError("At least one target must be provided when passed.")
            if isinstance(targets, str):
                targets = [targets]

            targets_proc = []
            for target in targets:
                target_enc = self.tokenizer.tokenize(target)
                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
                    logger.warning(
                        "The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format(
                            target, target_enc[0]
                        )
                    )
                targets_proc.append(target_enc[0])
            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))

        for i in range(batch_size):
            input_ids = inputs["input_ids"][i]
            result = []

            if self.framework == "tf":
                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index)

                logits = outputs[i, masked_index.item(), :]
                probs = tf.nn.softmax(logits)
                if targets is None:
                    topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
                    values, predictions = topk.values.numpy(), topk.indices.numpy()
                else:
                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
                    sort_inds = tf.reverse(tf.argsort(values), [0])
                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
                    predictions = target_inds[sort_inds.numpy()]
            else:
                masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)

                # Fill mask pipeline supports only one ${mask_token} per sample
                self.ensure_exactly_one_mask_token(masked_index.numpy())

                logits = outputs[i, masked_index.item(), :]
                probs = logits.softmax(dim=0)
                if targets is None:
                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
                else:
                    values = probs[..., target_inds]
                    sort_inds = values.argsort(dim=-1, descending=True)
                    values = values[..., sort_inds]
                    predictions = target_inds[sort_inds.numpy()]

            for v, p in zip(values.tolist(), predictions.tolist()):
                tokens = input_ids.numpy()
                tokens[masked_index] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                result.append(
                    {
                        "sequence": self.tokenizer.decode(tokens),
                        "score": v,
                        "token": p,
                        "token_str": self.tokenizer.convert_ids_to_tokens(p),
                    }
                )

            # Append
            results += [result]

        if len(results) == 1:
            return results[0]
        return results


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
            A list of labels to ignore.
        grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
    """,
)
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the `named entity recognition
    examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=token-classification>`__.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        ignore_labels=["O"],
        task: str = "",
        grouped_entities: bool = False,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self.ignore_labels = ignore_labels
        self.grouped_entities = grouped_entities

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            inputs (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with
            :obj:`grouped_entities=True`) with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word.
            - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
              corresponding token in the sentence.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        answers = []
        for sentence in inputs:

            # Manage correct placement of the tensors
            with self.device_placement():

                tokens = self.tokenizer(
                    sentence,
                    return_attention_mask=False,
                    return_tensors=self.framework,
                    truncation=True,
                )

                # Forward
                if self.framework == "tf":
                    entities = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        entities = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)

            entities = []
            # Filter to labels not in `self.ignore_labels`
            filtered_labels_idx = [
                (idx, label_idx)
                for idx, label_idx in enumerate(labels_idx)
                if self.model.config.id2label[label_idx] not in self.ignore_labels
            ]

            for idx, label_idx in filtered_labels_idx:

                entity = {
                    "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
                    "score": score[idx][label_idx].item(),
                    "entity": self.model.config.id2label[label_idx],
                    "index": idx,
                }

                entities += [entity]

            # Append grouped entities
            if self.grouped_entities:
                answers += [self.group_entities(entities)]
            # Append ungrouped entities
            else:
                answers += [entities]

        if len(answers) == 1:
            return answers[0]
        return answers

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.
        """
        # Get the first entity in the entity group
        entity = entities[0]["entity"]
        scores = [entity["score"] for entity in entities]
        tokens = [entity["word"] for entity in entities]

        entity_group = {
            "entity_group": entity,
            "score": np.mean(scores),
            "word": self.tokenizer.convert_tokens_to_string(tokens),
        }
        return entity_group

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.
        """

        entity_groups = []
        entity_group_disagg = []

        if entities:
            last_idx = entities[-1]["index"]

        for entity in entities:
            is_last_idx = entity["index"] == last_idx
            if not entity_group_disagg:
                entity_group_disagg += [entity]
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
                continue

            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated
            # entity group. The split strips the IOB prefixes ("B-", "I-") so that only the entity type is compared.
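            # e.g. with IOB tags, "B-PER" at index i followed by "I-PER" at index i + 1 is merged
            # into a single "PER" group.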
            if (
                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
                and entity["index"] == entity_group_disagg[-1]["index"] + 1
            ):
                entity_group_disagg += [entity]
                # Group the entities at the last entity
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
            else:
                entity_groups += [self.group_sub_entities(entity_group_disagg)]
                entity_group_disagg = [entity]
                # If it's the last entity, add it to the entity groups
                if is_last_idx:
                    entity_groups += [self.group_sub_entities(entity_group_disagg)]

        return entity_groups


NerPipeline = TokenClassificationPipeline


class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped to
    internal :class:`~transformers.SquadExample`.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a :class:`~transformers.SquadExample`
    from the supplied arguments.
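
    For example, all of the following call shapes are expected to be accepted (a sketch; ``handler`` stands for an
    instance of this class and the values are placeholders)::

        handler(question="Why?", context="Because.")
        handler(X={"question": "Why?", "context": "Because."})
        handler(data=[{"question": "Why?", "context": "Because."}])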
    """

    def __call__(self, *args, **kwargs):
        # Positional args are handled the same way as X and data, so forward them to avoid duplicating the logic
        if args is not None and len(args) > 0:
            if len(args) == 1:
                kwargs["X"] = args[0]
            else:
                kwargs["X"] = list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if "X" in kwargs or "data" in kwargs:
            inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]

            if isinstance(inputs, dict):
                inputs = [inputs]
            else:
                # Copy to avoid overriding arguments
                inputs = [i for i in inputs]

            for i, item in enumerate(inputs):
                if isinstance(item, dict):
                    if any(k not in item for k in ["question", "context"]):
                        raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

                    inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif not isinstance(item, SquadExample):
                    raise ValueError(
                        "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
                            "X" if "X" in kwargs else "data"
                        )
                    )

        # Tabular input
        elif "question" in kwargs and "context" in kwargs:
            if isinstance(kwargs["question"], str):
                kwargs["question"] = [kwargs["question"]]

            if isinstance(kwargs["context"], str):
                kwargs["context"] = [kwargs["context"]]

            inputs = [
                QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
            ]
        else:
            raise ValueError("Unknown arguments {}".format(kwargs))

        if not isinstance(inputs, list):
            inputs = [inputs]

        return inputs


@add_end_docstrings(PIPELINE_INIT_ARGS)
class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using any :obj:`ModelForQuestionAnswering`. See the `question answering examples
    <../task_summary.html#question-answering>`__ for more information.

    This question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=question-answering>`__.
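
    Usage (a minimal sketch; assumes a default question-answering model can be downloaded)::

        qa = pipeline("question-answering")
        qa(
            question="Where do I live?",
            context="My name is Sarah and I live in London.",
        )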
    """

    default_input_names = "question,context"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        task: str = "",
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            device=device,
            task=task,
            **kwargs,
        )

        self._args_parser = QuestionAnsweringArgumentHandler()
        self.check_model_type(
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
        )

    @staticmethod
    def create_sample(
        question: Union[str, List[str]], context: Union[str, List[str]]
    ) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the :class:`~transformers.SquadExample` internally. This helper method
        encapsulates all the logic for converting question(s) and context(s) to :class:`~transformers.SquadExample`.

        We currently support extractive question answering.

        Arguments:
            question (:obj:`str` or :obj:`List[str]`): The question(s) asked.
            context (:obj:`str` or :obj:`List[str]`): The context(s) in which we will look for the answer.

        Returns:
            One or a list of :class:`~transformers.SquadExample`: The corresponding :class:`~transformers.SquadExample`
            grouping question and context.
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            args (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`):
                One or several :class:`~transformers.SquadExample` containing the question and context.
            X (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context (will be treated
                the same way as if passed as the first positional argument).
            data (:class:`~transformers.SquadExample` or a list of :class:`~transformers.SquadExample`, `optional`):
                One or several :class:`~transformers.SquadExample` containing the question and context (will be treated
                the same way as if passed as the first positional argument).
            question (:obj:`str` or :obj:`List[str]`):
                One or several question(s) (must be used in conjunction with the :obj:`context` argument).
            context (:obj:`str` or :obj:`List[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                :obj:`question` argument).
            topk (:obj:`int`, `optional`, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood).
            doc_stride (:obj:`int`, `optional`, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split in several chunks
                with some overlap. This argument controls the size of that overlap.
            max_answer_len (:obj:`int`, `optional`, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (:obj:`int`, `optional`, defaults to 384):
                The maximum length of the total sentence (context + question) after tokenization. The context will be
                split in several chunks (using :obj:`doc_stride`) if needed.
            max_question_len (:obj:`int`, `optional`, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not we accept impossible as an answer.

        Return:
            A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the following keys:
            - **score** (:obj:`float`) -- The probability associated to the answer.
            - **start** (:obj:`int`) -- The start index of the answer (in the tokenized version of the input).
            - **end** (:obj:`int`) -- The end index of the answer (in the tokenized version of the input).
            - **answer** (:obj:`str`) -- The answer to the question.
        """
        # Set default values
        kwargs.setdefault("topk", 1)
        kwargs.setdefault("doc_stride", 128)
        kwargs.setdefault("max_answer_len", 15)
        kwargs.setdefault("max_seq_len", 384)
        kwargs.setdefault("max_question_len", 64)
        kwargs.setdefault("handle_impossible_answer", False)

        if kwargs["topk"] < 1:
            raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))

        if kwargs["max_answer_len"] < 1:
            raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

        # Convert inputs to features
        examples = self._args_parser(*args, **kwargs)
        features_list = [
            squad_convert_examples_to_features(
                examples=[example],
                tokenizer=self.tokenizer,
                max_seq_length=kwargs["max_seq_len"],
                doc_stride=kwargs["doc_stride"],
                max_query_length=kwargs["max_question_len"],
                padding_strategy=PaddingStrategy.MAX_LENGTH.value,
                is_training=False,
                tqdm_enabled=False,
            )
            for example in examples
        ]
        all_answers = []
        for features, example in zip(features_list, examples):
            model_input_names = self.tokenizer.model_input_names + ["input_ids"]
            fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}

            # Manage tensor allocation on correct device
            with self.device_placement():
                if self.framework == "tf":
                    fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                    start, end = self.model(fw_args)[:2]
                    start, end = start.numpy(), end.numpy()
                else:
                    with torch.no_grad():
                        # Retrieve the score for the context tokens only (removing question tokens)
                        fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
                        start, end = self.model(**fw_args)[:2]
                        start, end = start.cpu().numpy(), end.cpu().numpy()

            min_null_score = 1000000  # large and positive
            answers = []
            for (feature, start_, end_) in zip(features, start, end):
                # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
                undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask
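                # p_mask marks tokens that cannot be part of the answer (question and special tokens) with 1,
                # so flipping it and AND-ing with the attention mask keeps only real context tokens.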

                # Generate mask
                undesired_tokens_mask = undesired_tokens == 0.0

                # Make sure non-context indexes in the tensor cannot contribute to the softmax
                start_ = np.where(undesired_tokens_mask, -10000.0, start_)
                end_ = np.where(undesired_tokens_mask, -10000.0, end_)

                # Normalize logits and spans to retrieve the answer
                start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
                end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))

                if kwargs["handle_impossible_answer"]:
                    min_null_score = min(min_null_score, (start_[0] * end_[0]).item())

                # Mask CLS
                start_[0] = end_[0] = 0.0

                starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
                char_to_word = np.array(example.char_to_word_offset)

                # Convert the answer (tokens) back to the original text
                answers += [
                    {
                        "score": score.item(),
                        "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                        "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
                        "answer": " ".join(
                            example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
                        ),
                    }
                    for s, e, score in zip(starts, ends, scores)
                ]

            if kwargs["handle_impossible_answer"]:
                answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})

            answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
            all_answers += answers

        if len(all_answers) == 1:
            return all_answers[0]
        return all_answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Take the output of any :obj:`ModelForQuestionAnswering` and will generate probabilities for each span to be the
        actual answer.

        In addition, it filters out some unwanted/impossible cases, like an answer length greater than max_answer_len
        or an answer end position occurring before the start position. The method supports outputting the k best
        answers through the topk argument.

        Args:
            start (:obj:`np.ndarray`): Individual start probabilities for each token.
            end (:obj:`np.ndarray`): Individual end probabilities for each token.
            topk (:obj:`int`): Indicates how many possible answer span(s) to extract from the model output.
            max_answer_len (:obj:`int`): Maximum size of the answer to extract from the model's output.
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidates where end < start or end - start >= max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)
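        # e.g. with max_answer_len = 3, candidates[i, j] is non-zero only for j in [i, i + 2]:
        # spans ending before they start, or spanning more than 3 tokens, are zeroed out.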

        # Inspired by Chen et al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) < topk:
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]
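            # np.argpartition selects the top-k scores in O(n); only those k indices are then fully sorted.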

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]

    def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to the actual words in the initial
        context.

        Args:
            text (:obj:`str`): The actual context to extract the answer from.
            start (:obj:`int`): The answer starting token index.
            end (:obj:`int`): The answer end token index.

        Returns:
            Dictionary like :obj:`{'answer': str, 'start': int, 'end': int}`
        """
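        # Worked example (hypothetical, assuming one token per word): for text "Hello brave new world",
        # start=1 and end=2 yield {"answer": "brave new", "start": 6, "end": 15}.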
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for i, word in enumerate(text.split(" ")):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {
            "answer": " ".join(words),
            "start": max(0, char_start_idx),
            "end": min(len(text), char_end_idx),
        }


@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Pipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"summarization"`.

    The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
    currently, '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'. See the up-to-date
    list of available models on `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.

    Usage::

        # use bart in pytorch
        summarizer = pipeline("summarization")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

        # use t5 in tf
        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
    """

    def __init__(self, *args, **kwargs):
        kwargs.update(task="summarization")
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                One or several articles (or one list of articles) to summarize.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
              input.
            - **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) --
              The token ids of the summary.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
        assert len(documents) > 0, "Please provide a document to summarize"

        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(documents[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"

            documents = ([prefix + document for document in documents[0]],)
            padding = True

        elif isinstance(documents[0], str):
            documents = (prefix + documents[0],)
            padding = False
        else:
            raise ValueError(
                " `documents[0]`: {} has the wrong format. It should be of type `str` or type `list`".format(
                    documents[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*documents, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]
            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            min_length = generate_kwargs.get("min_length", self.model.config.min_length)
            if input_length < min_length // 2:
                logger.warning(
                    "Your min_length is set to {}, but your input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
                        min_length, input_length
                    )
                )

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length < max_length:
                logger.warning(
                    "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
                        max_length, input_length
                    )
                )

            summaries = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )

            results = []
            for summary in summaries:
                record = {}
                if return_tensors:
                    record["summary_token_ids"] = summary
                if return_text:
                    record["summary_text"] = self.tokenizer.decode(
                        summary,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


@add_end_docstrings(PIPELINE_INIT_ARGS)
class TranslationPipeline(Pipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=translation>`__.

    Usage::

        en_fr_translator = pipeline("translation_en_to_fr")
        en_fr_translator("How old are you?")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Texts to be translated.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

        if isinstance(args[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            args = ([prefix + text for text in args[0]],)
            padding = True

        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                " `args[0]`: {} has the wrong format. It should be of type `str` or type `list`".format(
                    args[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            if input_length > 0.9 * max_length:
                logger.warning(
                    "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
                        input_length, max_length
                    )
                )

            translations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )
            results = []
            for translation in translations:
                record = {}
                if return_tensors:
                    record["translation_token_ids"] = translation
                if return_text:
                    record["translation_text"] = self.tokenizer.decode(
                        translation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


@add_end_docstrings(PIPELINE_INIT_ARGS)
class Text2TextGenerationPipeline(Pipeline):
    """
    Pipeline for text to text generation using seq2seq models.

    This Text2TextGenerationPipeline pipeline can currently be loaded from :func:`~transformers.pipeline` using the
    following task identifier: :obj:`"text2text-generation"`.

    The models that this pipeline can use are models that have been fine-tuned on a sequence-to-sequence task. See the
    up-to-date list of available models on `huggingface.co/models <https://huggingface.co/models?filter=seq2seq>`__.

    Usage::

        text2text_generator = pipeline("text2text-generation")
        text2text_generator("question: What is 42 ? context: 42 is the answer to life, the universe and everything")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Generate the output text(s) using text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Input text for the encoder.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
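
        Example (a minimal sketch, assuming a T5-style checkpoint that understands task prefixes)::

            text2text = pipeline("text2text-generation")
            text2text("translate English to German: Good morning", return_text=True)
            # e.g. [{'generated_text': '...'}], the exact text depends on the model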
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

        if isinstance(args[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            padding = True

        elif isinstance(args[0], str):
            padding = False
        else:
            raise ValueError(
                "`documents[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
                    args[0]
                )
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)

            generations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )
            results = []
            for generation in generations:
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generation
                if return_text:
                    record["generated_text"] = self.tokenizer.decode(
                        generation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                results.append(record)
            return results


class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    :class:`~transformers.ConversationalPipeline`. The conversation contains a number of utility functions to manage
    the addition of new user input and generated model responses. A conversation needs to contain an unprocessed user
    input before being passed to the :class:`~transformers.ConversationalPipeline`. This user input is either created
    when the class is instantiated, or by calling :obj:`conversation.add_user_input("input")` after a conversation
    turn.

    Arguments:
        text (:obj:`str`, `optional`):
            The initial user input to start the conversation. If not provided, a user input needs to be provided
            manually using the :meth:`~transformers.Conversation.add_user_input` method before the conversation can
            begin.
        conversation_id (:obj:`uuid.UUID`, `optional`):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big Lebowski.")

        conversation.add_user_input("Is it good?")
    """

    def __init__(self, text: str = None, conversation_id: UUID = None):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        self.uuid: UUID = conversation_id
        self.past_user_inputs: List[str] = []
        self.generated_responses: List[str] = []
        self.history: List[int] = []
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This populates the internal :obj:`new_user_input`
        field.

        Args:
            text (:obj:`str`): The user input for the next conversation round.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not existing and unprocessed user input should be overwritten when this function is called.
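
        Example (illustrative)::

            conversation = Conversation("Going to the movies tonight - any suggestions?")
            conversation.add_user_input("Is it an action movie?")           # ignored: an input is already pending
            conversation.add_user_input("Is it a comedy?", overwrite=True)  # replaces the pending input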
        """
        if self.new_user_input:
            if overwrite:
                logger.warning(
                    'User input added while unprocessed input was pending: "{}" was overwritten with: "{}".'.format(
                        self.new_user_input, text
                    )
                )
                self.new_user_input = text
            else:
                logger.warning(
                    'User input added while unprocessed input was pending: "{}"; new input ignored: "{}". '
                    "Set `overwrite` to True to overwrite the unprocessed user input.".format(self.new_user_input, text)
                )
        else:
            self.new_user_input = text

    def mark_processed(self):
        """
        Mark the conversation as processed (moves the content of :obj:`new_user_input` to :obj:`past_user_inputs`) and
        empties the :obj:`new_user_input` field.
        """
        if self.new_user_input:
            self.past_user_inputs.append(self.new_user_input)
        self.new_user_input = None

    def append_response(self, response: str):
        """
        Append a response to the list of generated responses.

        Args:
            response (:obj:`str`): The model generated response.
        """
        self.generated_responses.append(response)

    def set_history(self, history: List[int]):
        """
        Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
        The history is used by the model to generate responses based on the previous conversation turns.

        Args:
            history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
        """
        self.history = history

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Return:
            :obj:`str`:

            Example::

                Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
                user >> Going to the movies tonight - any suggestions?
                bot >> The Big Lebowski
        """
        output = "Conversation id: {} \n".format(self.uuid)
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            output += "user >> {} \n".format(user_input)
            output += "bot >> {} \n".format(generated_response)
        if self.new_user_input is not None:
            output += "user >> {} \n".format(self.new_user_input)
        return output


@add_end_docstrings(
    PIPELINE_INIT_ARGS,
    r"""
        min_length_for_response (:obj:`int`, `optional`, defaults to 32):
            The minimum length (in number of tokens) for a response.
    """,
)
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    This conversational pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: `'microsoft/DialoGPT-small'`, `'microsoft/DialoGPT-medium'`, `'microsoft/DialoGPT-large'`. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=conversational>`__.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # We need at least an eos_token
        assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
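
        # If the tokenizer has no dedicated padding token, fall back to the EOS token so that
        # batched inputs can still be padded to a common length.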
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
            updated generated responses for those containing a new user input.
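
        Example (a minimal sketch, assuming the default DialoGPT checkpoint)::

            conversational_pipeline = pipeline("conversational")
            conversation = conversational_pipeline(Conversation("Hi there!"))
            conversation.generated_responses[-1]  # the model's latest reply as a string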
        """

        # Input validation
        if isinstance(conversations, list):
            for conversation in conversations:
                assert isinstance(
                    conversation, Conversation
                ), "ConversationalPipeline expects a Conversation or list of Conversations as an input"
                if conversation.new_user_input is None:
                    raise ValueError(
                        "Conversation with UUID {} does not contain new user input to process. "
                        "Add user inputs with the conversation's `add_user_input` method".format(
                            conversation.uuid
                        )
                    )
            assert (
                self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
        elif isinstance(conversations, Conversation):
            conversations = [conversations]
        else:
            raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input")

        with self.device_placement():

            inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
            histories = [conversation.history for conversation in conversations]
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            inputs = self._concat_inputs_history(inputs, histories, max_length)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            if input_length > 0.9 * max_length:
                logger.warning(
                    "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                    "You might consider trimming the early phase of the conversation".format(input_length, max_length)
                )
            generated_responses = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )

            cleaned_history = self._clean_padding_history(generated_responses)
            output = []
            for conversation_index, conversation in enumerate(conversations):
                conversation.mark_processed()
                conversation.generated_responses.append(
                    self.tokenizer.decode(
                        cleaned_history[conversation_index][input_length:],
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                )
                conversation.set_history(cleaned_history[conversation_index])
                output.append(conversation)
            if len(output) == 1:
                return output[0]
            else:
                return output

    def _parse_and_tokenize(self, inputs, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
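
        Example (illustrative; exact ids depend on the tokenizer, here GPT-2 whose EOS id is 50256)::

            ["Hello!"] -> [[15496, 0, 50256]]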
        """
        # Parse arguments
        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for input in inputs:
            input.append(self.tokenizer.eos_token_id)
        return inputs

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
        an input:

            - at the end of the concatenated history and new user input, so that all inputs to the model have the same
              length
            - at the end of the generated response, as some responses will be longer than others
        This method cleans up these padding tokens so that the history for each conversation is not impacted by the
        batching process.
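
        Example (illustrative, with ``pad_token_id=0``; runs of padding collapse to a single token)::

            [12, 34, 0, 0, 0, 56, 0] -> [12, 34, 0, 56, 0]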
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in sequence:
                if token == self.tokenizer.pad_token_id:
                    if is_previous_pad:
                        continue
                    else:
                        is_previous_pad = True
                else:
                    is_previous_pad = False
                if self.framework == "pt":
                    sequence_tokens.append(token.item())
                else:
                    sequence_tokens.append(int(token.numpy()))

            outputs.append(sequence_tokens)
        return outputs

    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
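
        Example (illustrative, with ``eos_token_id=0``, ``max_length=10`` and ``min_length_for_response=4``)::

            history + new_input = [5, 6, 0, 7, 8, 9, 0, 1, 2]   # 9 tokens, budget is 10 - 4 = 6
            trimmed input       = [7, 8, 9, 0, 1, 2]            # cut just after the first EOS token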
        """
        outputs = []
        for new_input, history in zip(inputs, histories):
            if history is not None:
                new_input = history + new_input
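            # If the combined history and new input leave less than `min_length_for_response` tokens
            # for generation, drop the oldest turns by cutting at EOS boundaries until they fit.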
            if len(new_input) > max_length - self.min_length_for_response:
                cutoff_eos_index = 0
                while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
                    if cutoff_eos_index >= len(new_input):
                        break
                    cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
                    if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
                        break
                    else:
                        new_input = new_input[cutoff_eos_index + 1 :]
            outputs.append(new_input)
        padded_outputs = self.tokenizer.pad(
            {"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework
        )
        return padded_outputs


# Register all the supported tasks here
SUPPORTED_TASKS = {
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": TFAutoModel if is_tf_available() else None,
        "pt": AutoModel if is_torch_available() else None,
        "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
    },
    "sentiment-analysis": {
        "impl": TextClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
            },
        },
    },
    "ner": {
        "impl": TokenClassificationPipeline,
        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
        "pt": AutoModelForTokenClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
            },
        },
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
        "default": {
            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelForMaskedLM if is_tf_available() else None,
        "pt": AutoModelForMaskedLM if is_torch_available() else None,
        "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    "translation": {
        "impl": TranslationPipeline,
        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {
            ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
            ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
            ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
        },
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
        "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
        },
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
    },
}
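
# For example, SUPPORTED_TASKS["summarization"]["impl"] is the pipeline class implementing the task,
# and SUPPORTED_TASKS["summarization"]["default"]["model"]["pt"] names the default PyTorch checkpoint.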


def check_task(task: str) -> Tuple[Dict, Any]:
    """
    Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
    default models if they exist.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`
            - :obj:`"sentiment-analysis"`
            - :obj:`"ner"`
            - :obj:`"question-answering"`
            - :obj:`"fill-mask"`
            - :obj:`"summarization"`
            - :obj:`"translation_xx_to_yy"`
            - :obj:`"translation"`
            - :obj:`"text-generation"`
            - :obj:`"conversational"`

    Returns:
        (task_defaults :obj:`dict`, task_options :obj:`tuple` or :obj:`None`): The actual dictionary required to
        initialize the pipeline and some extra task options for parametrized tasks like "translation_XX_to_YY".
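
    Example (illustrative)::

        task_defaults, task_options = check_task("translation_en_to_fr")
        # task_options == ("en", "fr"); non-parametrized tasks such as "ner" return (task_defaults, None)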


    """
    if task in SUPPORTED_TASKS:
        targeted_task = SUPPORTED_TASKS[task]
        return targeted_task, None

    if task.startswith("translation"):
        tokens = task.split("_")
        if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
            targeted_task = SUPPORTED_TASKS["translation"]
            return targeted_task, (tokens[1], tokens[3])
        raise KeyError("Invalid translation task {}, use 'translation_XX_to_YY' format".format(task))

    raise KeyError(
        "Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()) + ["translation_XX_to_YY"])
    )


def pipeline(
    task: str,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    framework: Optional[str] = None,
    use_fast: bool = False,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a :class:`~transformers.Pipeline`.

    Pipelines are made of:

        - A :doc:`tokenizer <tokenizer>` in charge of mapping raw textual input to tokens.
        - A :doc:`model <model>` to make predictions from the inputs.
        - Some (optional) post processing for enhancing the model's output.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
            - :obj:`"sentiment-analysis"`: will return a :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"ner"`: will return a :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
            - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
            - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
            - :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
            - :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
            - :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)
            or :class:`~transformers.TFPreTrainedModel` (for TensorFlow).

            If not provided, the default for the :obj:`task` will be loaded.
        config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from
            :class:`~transformers.PretrainedConfig`.

            If not provided, the default configuration file for the requested model will be used. That means that if
            :obj:`model` is given, its default configuration will be used. However, if :obj:`model` is not supplied,
            this :obj:`task`'s default model's config is used instead.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained tokenizer inheriting from :class:`~transformers.PreTrainedTokenizer`.

            If not provided, the default tokenizer for the given :obj:`model` will be loaded (if it is a string). If
            :obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if
            it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer
            for the given :obj:`task` will be loaded.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
        kwargs:
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        :class:`~transformers.Pipeline`: A suitable pipeline for the task.

    Examples::

        >>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

        >>> # Sentiment analysis pipeline
        >>> pipeline('sentiment-analysis')

        >>> # Question answering pipeline, specifying the checkpoint identifier
        >>> pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')

        >>> # Named entity recognition pipeline, passing in a specific model and tokenizer
        >>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        >>> pipeline('ner', model=model, tokenizer=tokenizer)
    """
    # Retrieve the task
    targeted_task, task_options = check_task(task)

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        # At that point framework might still be undetermined
        model = get_default_model(targeted_task, framework, task_options)

    framework = framework or get_framework(model)

    task_class, model_class = targeted_task["impl"], targeted_task[framework]

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str):
            tokenizer = model
        elif isinstance(config, str):
            tokenizer = config
        else:
            # Impossible to guess which tokenizer to use here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
            )

    modelcard = None
    # Try to infer modelcard from model or config name (if provided as str)
    if isinstance(model, str):
        modelcard = model
    elif isinstance(config, str):
        modelcard = config

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, (str, tuple)):
        if isinstance(tokenizer, tuple):
            # For tuple we have (tokenizer name, {kwargs})
            use_fast = tokenizer[1].pop("use_fast", use_fast)
            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], use_fast=use_fast, **tokenizer[1])
        else:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=use_fast)

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
        modelcard = ModelCard.from_pretrained(modelcard)

    # Instantiate model if needed
    if isinstance(model, str):
        # Handle transparent TF/PT model conversion
        model_kwargs = {}
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )
        model = model_class.from_pretrained(model, config=config, **model_kwargs)
        if task == "translation" and model.config.task_specific_params:
            for key in model.config.task_specific_params:
                if key.startswith("translation"):
                    task = key
                    warnings.warn(
                        '"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
                            task
                        ),
                        UserWarning,
                    )
                    break

    return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)