# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import json
import logging
import os
import pickle
import sys
from abc import ABC, abstractmethod
from contextlib import contextmanager
from os.path import abspath, exists
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_roberta import RobertaConfig
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLMConfig
from .data import SquadExample, squad_convert_examples_to_features
from .file_utils import is_tf_available, is_torch_available
from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer


if is_tf_available():
    import tensorflow as tf
    from .modeling_tf_auto import (
        TFAutoModel,
        TFAutoModelForSequenceClassification,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForTokenClassification,
        TFAutoModelWithLMHead,
    )

if is_torch_available():
    import torch
    from .modeling_auto import (
        AutoModel,
        AutoModelForSequenceClassification,
        AutoModelForQuestionAnswering,
        AutoModelForTokenClassification,
        AutoModelWithLMHead,
    )


logger = logging.getLogger(__name__)


def get_framework(model=None):
    """ Select framework (TensorFlow/PyTorch) to use.
        If both frameworks are installed and no specific model is provided, defaults to using PyTorch.
    """
    if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
        # Both frameworks are available but the user supplied a model class instance.
        # Try to guess which framework to use from the model classname
        framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    elif not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    else:
        # framework = 'tf' if is_tf_available() else 'pt'
        framework = "pt" if is_torch_available() else "tf"
    return framework
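
# Example (a minimal sketch, assuming both backends are installed; tf_model is a
# hypothetical TF model instance):
#     get_framework()                   # -> "pt" (defaults to PyTorch)
#     get_framework(tf_model)           # -> "tf" (class name starts with "TF")
#     get_framework("bert-base-cased")  # -> "pt" (string names fall back to the default)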


class ArgumentHandler(ABC):
    """
    Base interface for handling varargs for each Pipeline
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class DefaultArgumentHandler(ArgumentHandler):
    """
    Default varargs argument parser handling parameters for each Pipeline
    """

    def __call__(self, *args, **kwargs):
        if "X" in kwargs:
            return kwargs["X"]
        elif "data" in kwargs:
            return kwargs["data"]
        elif len(args) == 1:
            if isinstance(args[0], list):
                return args[0]
            else:
                return [args[0]]
        elif len(args) > 1:
            return list(args)
        raise ValueError("Unable to infer the format of the provided data (X=, data=, ...)")


class PipelineDataFormat:
    """
    Base class for all the pipeline supported data formats, both for reading and writing.
    Supported data formats currently include:
     - JSON
     - CSV
     - stdin/stdout (pipe)

    PipelineDataFormat also includes some utilities to work with multi-column data, like mapping from dataset
    columns to pipeline keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
    """

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
        self.output_path = output_path
        self.input_path = input_path
        self.column = column.split(",") if column is not None else [""]
        self.is_multi_columns = len(self.column) > 1

        if self.is_multi_columns:
            self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]

        if output_path is not None and not overwrite:
            if exists(abspath(self.output_path)):
                raise OSError("{} already exists on disk".format(self.output_path))

        if input_path is not None:
            if not exists(abspath(self.input_path)):
                raise OSError("{} doesn't exist on disk".format(self.input_path))

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: dict):
        """
        Save the provided data object with the representation for the current `DataFormat`.
        :param data: data to store
        :return:
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Save the provided data object as pickle-formatted binary data on the disk.
        :param data: data to store
        :return: (str) Path where the data has been saved
        """
        path, _ = os.path.splitext(self.output_path)
        binary_path = os.path.extsep.join((path, "pickle"))

        with open(binary_path, "wb+") as f_output:
            pickle.dump(data, f_output)

        return binary_path

    @staticmethod
    def from_str(
        format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False
    ):
        if format == "json":
            return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        elif format == "csv":
            return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        elif format == "pipe":
            return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
        else:
            raise KeyError("Unknown reader {} (Available readers are json/csv/pipe)".format(format))


class CsvPipelineDataFormat(PipelineDataFormat):
    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        with open(self.input_path, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        with open(self.output_path, "w") as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        with open(input_path, "r") as f:
            self._entries = json.load(f)

    def __iter__(self):
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        with open(self.output_path, "w") as f:
            json.dump(data, f)


class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the python process.
    For multi-column data, columns should be separated by \t

    If columns are provided, then the output will be a dictionary with {column_x: value_x}
    """

    def __iter__(self):
        for line in sys.stdin:
            # Split for multi-columns
            if "\t" in line:
                line = line.split("\t")
                if self.column:
                    # Dictionary to map arguments
                    yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
                else:
                    yield tuple(line)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        if self.output_path is None:
            raise KeyError(
                "When using piped input, a pipeline outputting large objects requires an output file path. "
                "Please provide such an output path through the --output argument."
            )

        return super().save_binary(data)
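
# Example (a minimal sketch): feeding tab-separated fields through stdin, assuming the
# format was built with column="question=0,context=1" so fields are mapped in order:
#     echo -e "Who wrote it?\tIt was written by X." | <command reading this format>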


class _ScikitCompat(ABC):
    """
    Interface layer for scikit-learn and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()


class Pipeline(_ScikitCompat):
    """
    Base class implementing pipelined operations.
    Pipeline workflow is defined as a sequence of the following operations:
        Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument. Users can specify the
    device argument as an integer: -1 means "CPU", >= 0 refers to the CUDA device ordinal.

    Some pipelines, like FeatureExtractionPipeline ('feature-extraction'), output large tensor
    objects as nested lists. To avoid dumping such large structures as textual data, we provide
    the binary_output constructor argument. If set to True, the output will be stored in
    pickle format.

    Arguments:
        **model**: ``(str, PreTrainedModel, TFPreTrainedModel)``:
            Reference to the model to use through this pipeline.

        **tokenizer**: ``(str, PreTrainedTokenizer)``:
            Reference to the tokenizer to use through this pipeline.

        **args_parser**: ``ArgumentHandler``:
            Reference to the object in charge of parsing supplied pipeline parameters.

        **device**: ``int``:
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >= 0 will run
            the model on the associated CUDA device id.

        **binary_output** ``bool`` (default: False):
            Flag indicating if the output of the pipeline should happen in a binary format (i.e. pickle) or as raw text.

    Return:
        Pipeline returns a list or a dictionary depending on:
         - Whether the user supplied multiple samples
         - Whether the pipeline exposes multiple fields in the output object

    Examples:
        nlp = pipeline('ner')
        nlp = pipeline('ner', model='...', config='...', tokenizer='...')
        nlp = NerPipeline(model='...', config='...', tokenizer='...')
        nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
    """

    default_input_names = None

    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
    ):
        if framework is None:
            framework = get_framework()

        self.model = model
        self.tokenizer = tokenizer
        self.modelcard = modelcard
        self.framework = framework
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
        self.binary_output = binary_output
        self._args_parser = args_parser or DefaultArgumentHandler()

        # Special handling
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

    def save_pretrained(self, save_directory):
        """
        Save the pipeline's model and tokenizer to the specified save_directory
        """
        if not os.path.isdir(save_directory):
            logger.error("Provided path ({}) should be a directory".format(save_directory))
            return

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        if self.modelcard is not None:
            self.modelcard.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context manager allowing tensor allocation on the user-specified device in a framework-agnostic way.
        example:
            # Explicitly ask for tensor allocation on CUDA device :0
            nlp = pipeline(..., device=0)
            with nlp.device_placement():
                # Every framework specific tensor allocation will be done on the requested device
                output = nlp(...)
        Returns:
            Context manager
        """
        if self.framework == "tf":
            with tf.device("/CPU:0" if self.device == -1 else "/device:GPU:{}".format(self.device)):
                yield
        else:
            if self.device.type == "cuda":
                torch.cuda.set_device(self.device)

            yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.
        :param inputs:
        :return:
        """
        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
        """
        Generates the input dictionary with model-specific parameters.

        Returns:
            dict holding all the required parameters for the model's forward
        """
        args = ["input_ids", "attention_mask"]

        if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig)):
            args += ["token_type_ids"]

        # PR #1548 (CLI) There is an issue with attention_mask
        # if 'xlnet' in model_type or 'xlm' in model_type:
        #     args += ['cls_index', 'p_mask']

        if isinstance(features, dict):
            return {k: features[k] for k in args}
        else:
            return {k: [feature[k] for feature in features] for k in args}

    def _parse_and_tokenize(self, *texts, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        inputs = self._args_parser(*texts, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(
            inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
        )

        # Filter out features not available on specific models
        inputs = self.inputs_for_model(inputs)

        return inputs

    def __call__(self, *texts, **kwargs):
        inputs = self._parse_and_tokenize(*texts, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework specific forward dispatching.
        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array.
        Returns:
            Numpy array
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        else:
            return predictions.numpy()


class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using the base Model (no task-specific head).
    """

    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
        )

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs).tolist()
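
# Example (a minimal sketch):
#     nlp = pipeline("feature-extraction")
#     features = nlp("The quick brown fox")  # nested list of shape [1, num_tokens, hidden_size]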


class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using ModelForSequenceClassification head.
    """

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        # Softmax over the label axis (keepdims so the division broadcasts per sample)
        scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
        return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]


class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using ModelWithLMHead head.
    """

    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        topk=5,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
        )

        self.topk = topk

    def __call__(self, *args, **kwargs):
        inputs = self._parse_and_tokenize(*args, **kwargs)
        outputs = self._forward(inputs, return_tensors=True)

        results = []
        batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

        for i in range(batch_size):
            input_ids = inputs["input_ids"][i]
            result = []

            if self.framework == "tf":
                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy().item()
                logits = outputs[i, masked_index, :]
                probs = tf.nn.softmax(logits)
                topk = tf.math.top_k(probs, k=self.topk)
                values, predictions = topk.values.numpy(), topk.indices.numpy()
            else:
                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
                logits = outputs[i, masked_index, :]
                probs = logits.softmax(dim=0)
                values, predictions = probs.topk(self.topk)

            for v, p in zip(values.tolist(), predictions.tolist()):
                tokens = input_ids.numpy()
                tokens[masked_index] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                result.append({"sequence": self.tokenizer.decode(tokens), "score": v, "token": p})

            # Append
            results += [result]

        if len(results) == 1:
            return results[0]
        return results

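# Example (a minimal sketch; the default model uses "<mask>" as its mask token):
#     nlp = pipeline("fill-mask")
#     nlp("My name is <mask>.")  # -> top-k candidates, each {"sequence", "score", "token"}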


class NerPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using ModelForTokenClassification head.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        ignore_labels=["O"],
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=binary_output,
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self.ignore_labels = ignore_labels

    def __call__(self, *texts, **kwargs):
        inputs = self._args_parser(*texts, **kwargs)
        answers = []
        for sentence in inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                tokens = self.tokenizer.encode_plus(
                    sentence,
                    return_attention_mask=False,
                    return_tensors=self.framework,
                    max_length=self.tokenizer.max_len,
                )

                # Forward
                if self.framework == "tf":
                    entities = self.model(tokens)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        entities = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
            labels_idx = score.argmax(axis=-1)

            answer = []
            for idx, label_idx in enumerate(labels_idx):
                if self.model.config.id2label[label_idx] not in self.ignore_labels:
                    answer += [
                        {
                            "word": self.tokenizer.decode([int(input_ids[idx])]),
                            "score": score[idx][label_idx].item(),
                            "entity": self.model.config.id2label[label_idx],
                        }
                    ]

            # Append
            answers += [answer]
        if len(answers) == 1:
            return answers[0]
        return answers
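
# Example (a minimal sketch):
#     nlp = pipeline("ner")
#     nlp("My name is Wolfgang and I live in Berlin")
#     # -> [{"word": ..., "score": ..., "entity": ...}, ...] for every non-"O" token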


class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal SquadExample / SquadFeature structures.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a SquadExample from the command-line
    supplied arguments.
    """

    def __call__(self, *args, **kwargs):
        # Positional args: handling is essentially the same as X and data, so forward to avoid duplicating
        if args is not None and len(args) > 0:
            if len(args) == 1:
                kwargs["X"] = args[0]
            else:
                kwargs["X"] = list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if "X" in kwargs or "data" in kwargs:
            inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]

            if isinstance(inputs, dict):
                inputs = [inputs]
            else:
                # Copy to avoid overriding arguments
                inputs = [i for i in inputs]

            for i, item in enumerate(inputs):
                if isinstance(item, dict):
                    if any(k not in item for k in ["question", "context"]):
                        raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

                    inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif not isinstance(item, SquadExample):
                    raise ValueError(
                        "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
                            "X" if "X" in kwargs else "data"
                        )
                    )

        # Tabular input
        elif "question" in kwargs and "context" in kwargs:
            if isinstance(kwargs["question"], str):
                kwargs["question"] = [kwargs["question"]]

            if isinstance(kwargs["context"], str):
                kwargs["context"] = [kwargs["context"]]

            inputs = [
                QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
            ]
        else:
            raise ValueError("Unknown arguments {}".format(kwargs))

        if not isinstance(inputs, list):
            inputs = [inputs]

        return inputs
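
# Example (a minimal sketch, with handler = QuestionAnsweringArgumentHandler()); all of
# these normalize to a list of SquadExample:
#     handler(question="Who?", context="He did it.")
#     handler({"question": "Who?", "context": "He did it."})
#     handler(X=[{"question": "Who?", "context": "He did it."}])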


class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using ModelForQuestionAnswering head.
    """

    default_input_names = "question,context"

    def __init__(
        self,
        model,
        tokenizer: Optional[PreTrainedTokenizer],
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        **kwargs
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=QuestionAnsweringArgumentHandler(),
            device=device,
            **kwargs,
        )

    @staticmethod
    def create_sample(
        question: Union[str, List[str]], context: Union[str, List[str]]
    ) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the SquadExample/SquadFeatures internally.
        This helper method encapsulates all the logic for converting question(s) and context(s) to SquadExample(s).
        We currently support extractive question answering.
        Arguments:
             question: (str, List[str]) The question(s) to ask of the associated context(s)
             context: (str, List[str]) The context(s) in which we will look for the answer.

        Returns:
            SquadExample initialized with the corresponding question and context.
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def __call__(self, *texts, **kwargs):
        """
        Args:
            We support multiple use-cases, the following are exclusive:
            X: sequence of SquadExample
            data: sequence of SquadExample
            question: (str, List[str]), batch of question(s) to map along with context
            context: (str, List[str]), batch of context(s) associated with the provided question keyword argument
        Returns:
            dict: {'answer': str, 'score': float, 'start': int, 'end': int}
            answer: the textual answer in the initial context
            score: the score the model assigned to the current answer
            start: the character index in the original string corresponding to the beginning of the answer's span
            end: the character index in the original string corresponding to the end of the answer's span
        """
        # Set default values
        kwargs.setdefault("topk", 1)
        kwargs.setdefault("doc_stride", 128)
        kwargs.setdefault("max_answer_len", 15)
        kwargs.setdefault("max_seq_len", 384)
        kwargs.setdefault("max_question_len", 64)

        if kwargs["topk"] < 1:
            raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))

        if kwargs["max_answer_len"] < 1:
            raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

        # Convert inputs to features
        examples = self._args_parser(*texts, **kwargs)
        features_list = [
            squad_convert_examples_to_features(
                [example],
                self.tokenizer,
                kwargs["max_seq_len"],
                kwargs["doc_stride"],
                kwargs["max_question_len"],
                False,
            )
            for example in examples
        ]

        all_answers = []
        for features, example in zip(features_list, examples):
            fw_args = self.inputs_for_model([f.__dict__ for f in features])

            # Manage tensor allocation on correct device
            with self.device_placement():
                if self.framework == "tf":
                    fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                    start, end = self.model(fw_args)
                    start, end = start.numpy(), end.numpy()
                else:
                    with torch.no_grad():
                        # Retrieve the score for the context tokens only (removing question tokens)
                        fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
                        start, end = self.model(**fw_args)
                        start, end = start.cpu().numpy(), end.cpu().numpy()

            answers = []
            for (feature, start_, end_) in zip(features, start, end):
                # Normalize logits and spans to retrieve the answer
                start_ = np.exp(start_) / np.sum(np.exp(start_))
                end_ = np.exp(end_) / np.sum(np.exp(end_))

                # Mask padding and question
                start_, end_ = (
                    start_ * np.abs(np.array(feature.p_mask) - 1),
                    end_ * np.abs(np.array(feature.p_mask) - 1),
                )

                # TODO : What happens if not possible
                # Mask CLS
                start_[0] = end_[0] = 0

                starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
                char_to_word = np.array(example.char_to_word_offset)

                # Convert the answer (tokens) back to the original text
                answers += [
                    {
                        "score": score.item(),
                        "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                        "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
                        "answer": " ".join(
                            example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
                        ),
                    }
                    for s, e, score in zip(starts, ends, scores)
                ]

            answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
            all_answers += answers

        if len(all_answers) == 1:
            return all_answers[0]
        return all_answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Takes the output of any QuestionAnswering head and generates probabilities for each span to be
        the actual answer.
        In addition, it filters out some unwanted/impossible cases like answer len being greater than
        max_answer_len or answer end position being before the starting position.
        The method supports outputting the k-best answers through the topk argument.

        Args:
            start: numpy array, holding individual start probabilities for each token
            end: numpy array, holding individual end probabilities for each token
            topk: int, indicates how many possible answer span(s) to extract from the model's output
            max_answer_len: int, maximum size of the answer to extract from the model's output
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidate with end < start and end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) < topk:
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]
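
# Worked example (a minimal sketch) of the span scoring above: with
#     start = np.array([0.1, 0.6, 0.3]) and end = np.array([0.2, 0.3, 0.5]),
# the outer product gives score(i, j) = start[i] * end[j]; np.triu keeps spans with
# j >= i, np.tril(..., max_answer_len - 1) drops spans longer than max_answer_len,
# and the best span is (start=1, end=2) with score 0.6 * 0.5 = 0.3.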

    def span_to_answer(self, text: str, start: int, end: int):
        """
        When decoding from token probabilities, this method maps token indexes to actual words in
        the initial context.

        Args:
            text: str, the actual context to extract the answer from
            start: int, starting answer token index
            end: int, ending answer token index

        Returns:
            dict: {'answer': str, 'start': int, 'end': int}
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for i, word in enumerate(text.split(" ")):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}


# Register all the supported tasks here
SUPPORTED_TASKS = {
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": TFAutoModel if is_tf_available() else None,
        "pt": AutoModel if is_torch_available() else None,
        "default": {
            "model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"},
            "config": None,
            "tokenizer": "distilbert-base-cased",
        },
    },
    "sentiment-analysis": {
        "impl": TextClassificationPipeline,
        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
            },
            "config": "distilbert-base-uncased-finetuned-sst-2-english",
            "tokenizer": "distilbert-base-uncased",
        },
    },
    "ner": {
        "impl": NerPipeline,
        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
        "pt": AutoModelForTokenClassification if is_torch_available() else None,
        "default": {
            "model": {
                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
            },
            "config": "dbmdz/bert-large-cased-finetuned-conll03-english",
            "tokenizer": "bert-large-cased",
        },
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
        "default": {
            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
            "config": None,
            "tokenizer": "distilbert-base-cased",
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelWithLMHead if is_torch_available() else None,
        "default": {
            "model": {"pt": "distilroberta-base", "tf": "distilroberta-base"},
            "config": None,
            "tokenizer": "distilroberta-base",
        },
    },
}


def pipeline(
    task: str,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    modelcard: Optional[Union[str, ModelCard]] = None,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a pipeline.
    Pipelines are made of:
        A Tokenizer instance in charge of mapping raw textual inputs to tokens
        A Model instance
        Some (optional) post processing for enhancing model's output

    Examples:
        pipeline('sentiment-analysis')
        pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')
        pipeline('ner', model=AutoModel.from_pretrained(...), tokenizer=AutoTokenizer.from_pretrained(...))
        pipeline('ner', model='dbmdz/bert-large-cased-finetuned-conll03-english', tokenizer='bert-base-cased')
        pipeline('ner', model='https://...pytorch-model.bin', config='https://...config.json', tokenizer='bert-base-cased')
    """
    # Retrieve the task
    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    framework = get_framework(model)

    targeted_task = SUPPORTED_TASKS[task]
    task, model_class = targeted_task["impl"], targeted_task[framework]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        models, config, tokenizer = tuple(targeted_task["default"].values())
        model = models[framework]

    # Try to infer tokenizer from model or config name (if provided as str)
    if tokenizer is None:
        if isinstance(model, str) and model in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
            tokenizer = model
        elif isinstance(config, str) and config in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
            tokenizer = config
        else:
            # Impossible to guess what is the right tokenizer here
            raise Exception(
                "Impossible to guess which tokenizer to use. "
                "Please provide a PreTrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer."
            )

    # Try to infer modelcard from model or config name (if provided as str)
    if modelcard is None:
        # Try to fall back on one of the provided strings for model or config (will replace the suffix)
        if isinstance(model, str):
            modelcard = model
        elif isinstance(config, str):
            modelcard = config

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
        modelcard = ModelCard.from_pretrained(modelcard)

    # Instantiate model if needed
    if isinstance(model, str):
        # Handle transparent TF/PT model conversion
        model_kwargs = {}
        if framework == "pt" and model.endswith(".h5"):
            model_kwargs["from_tf"] = True
            logger.warning(
                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                "Trying to load the model with PyTorch."
            )
        elif framework == "tf" and model.endswith(".bin"):
            model_kwargs["from_pt"] = True
            logger.warning(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with TensorFlow."
            )
        model = model_class.from_pretrained(model, config=config, **model_kwargs)

    return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)