# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import csv
import json
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import groupby
from os.path import abspath, exists
from typing import Union, Optional, Tuple, List, Dict

import numpy as np

from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
    SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger

if is_tf_available():
    from transformers import TFAutoModel, TFAutoModelForSequenceClassification, \
        TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification

if is_torch_available():
    import torch
    from transformers import AutoModel, AutoModelForSequenceClassification, \
        AutoModelForQuestionAnswering, AutoModelForTokenClassification


class ArgumentHandler(ABC):
    """
    Base interface for handling varargs for each Pipeline
    """
    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class DefaultArgumentHandler(ArgumentHandler):
    """
    Default varargs argument parser handling parameters for each Pipeline
    """
    def __call__(self, *args, **kwargs):
        if 'X' in kwargs:
            return kwargs['X']
        elif 'data' in kwargs:
            return kwargs['data']
        elif len(args) == 1:
            if isinstance(args[0], list):
                return args[0]
            else:
                return [args[0]]
        elif len(args) > 1:
            return list(args)
        raise ValueError('Unable to infer the format of the provided data (X=, data=, ...)')
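# Illustrative call forms for DefaultArgumentHandler (a sketch, not part of the public API):
#   handler = DefaultArgumentHandler()
#   handler(X=['first sentence', 'second sentence'])  # -> ['first sentence', 'second sentence']
#   handler(data='single input')                      # -> 'single input' (X/data are returned as-is)
#   handler('single input')                           # -> ['single input']
#   handler('first', 'second')                        # -> ['first', 'second']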


class PipelineDataFormat:
    """
    Base class for all the pipeline supported data formats, both for reading and writing.
    Supported data formats currently include:
     - JSON
     - CSV
     - stdin/stdout (pipe)

    PipelineDataFormat also includes some utilities to work with multi-column data, like mapping from dataset
    columns to pipeline keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
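
    For example (illustrative column spec), column='question=q,context=c' maps the dataset
    columns 'q' and 'c' to the pipeline keyword arguments 'question' and 'context'.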
    """
    SUPPORTED_FORMATS = ['json', 'csv', 'pipe']

    def __init__(self, output: Optional[str], path: Optional[str], column: Optional[str]):
        self.output = output
        self.path = path
        self.column = column.split(',') if column else ['']
        self.is_multi_columns = len(self.column) > 1

        if self.is_multi_columns:
            self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]

        if output is not None:
            if exists(abspath(self.output)):
                raise OSError('{} already exists on disk'.format(self.output))

        if path is not None and not exists(abspath(self.path)):
            raise OSError('{} does not exist on disk'.format(self.path))

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: dict):
        raise NotImplementedError()

    @staticmethod
    def from_str(name: str, output: Optional[str], path: Optional[str], column: Optional[str]):
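        """
        Creates an instance of the right subclass for the given format name.
        Illustrative usage (file names are hypothetical):
            fmt = PipelineDataFormat.from_str('csv', output='answers.csv', path='questions.csv', column='question')
        """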
        if name == 'json':
            return JsonPipelineDataFormat(output, path, column)
        elif name == 'csv':
            return CsvPipelineDataFormat(output, path, column)
        elif name == 'pipe':
            return PipedPipelineDataFormat(output, path, column)
        else:
            raise KeyError('Unknown reader {} (Available readers are json/csv/pipe)'.format(name))


class CsvPipelineDataFormat(PipelineDataFormat):
    def __init__(self, output: Optional[str], path: Optional[str], column: Optional[str]):
        super().__init__(output, path, column)

    def __iter__(self):
        with open(self.path, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        with open(self.output, 'w') as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    def __init__(self, output: Optional[str], path: Optional[str], column: Optional[str]):
        super().__init__(output, path, column)

        with open(path, 'r') as f:
            self._entries = json.load(f)

    def __iter__(self):
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        with open(self.output, 'w') as f:
            json.dump(data, f)


class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the Python process.
    For multi-column data, columns should be separated by \t

    If columns are provided, then the output will be a dictionary with {column_x: value_x}
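
    Illustrative shell usage (script name is hypothetical):
        cat articles.txt | python run_pipeline.py --task text-classification --format pipe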
    """
    def __iter__(self):
        import sys
        for line in sys.stdin:

            # Split for multi-columns
            if '\t' in line:

                line = line.split('\t')
                if self.column:
                    # Dictionary to map arguments
                    yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
                else:
                    yield tuple(line)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        print(data)


class _ScikitCompat(ABC):
    """
    Interface layer for the Scikit and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()


class Pipeline(_ScikitCompat):
    """
    Base class implementing pipelined operations.
    Pipeline workflow is defined as a sequence of the following operations:
        Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
    """
    def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                 args_parser: ArgumentHandler = None, device: int = -1, **kwargs):

        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self._args_parser = args_parser or DefaultArgumentHandler()

        # Special handling: in the PyTorch case, move the model to the requested CUDA device
        if self.device >= 0 and not is_tf_available():
            self.model = self.model.to('cuda:{}'.format(self.device))

    def save_pretrained(self, save_directory):
        """
        Save the pipeline's model and tokenizer to the specified save_directory
        """
        if not os.path.isdir(save_directory):
            logger.error("Provided path ({}) should be a directory".format(save_directory))
            return

        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in a framework-agnostic way.
        example:
            # Explicitly ask for tensor allocation on CUDA device :0
            nlp = pipeline(..., device=0)
            with nlp.device_placement():
                # Every framework specific tensor allocation will be done on the requested device
                output = nlp(...)
        Returns:
            Context manager
        """
        if is_tf_available():
            import tensorflow as tf
            with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
                yield
        else:
            import torch
            if self.device >= 0:
                torch.cuda.set_device(self.device)

            yield

    def __call__(self, *texts, **kwargs):
        # Parse arguments
        inputs = self._args_parser(*texts, **kwargs)

        # Encode for forward
        with self.device_placement():
            inputs = self.tokenizer.batch_encode_plus(
                inputs, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
            )

            return self._forward(inputs)

    def _forward(self, inputs):
        """
        Internal framework specific forward dispatching.
        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
        Returns:
            Numpy array
        """
        if is_tf_available():
            # TODO trace model
            predictions = self.model(inputs)[0]
        else:
            import torch
            with torch.no_grad():
                predictions = self.model(**inputs)[0]

        return predictions.numpy()


class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using the base model with no task-specific head; returns the
    hidden states as nested lists of floats.
    """
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs).tolist()


class TextClassificationPipeline(Pipeline):
    """
    Text classification pipeline using a ModelForSequenceClassification head.
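
    Example output (illustrative labels and scores):
        [{'label': 'POSITIVE', 'score': 0.99}]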
    """

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        # Softmax over the label axis (keepdims so the division broadcasts per row)
        scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
        return [{'label': self.model.config.id2label[item.argmax()], 'score': item.max()} for item in scores]


class NerPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using a ModelForTokenClassification head.
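
    Example output (illustrative values):
        [[{'word': 'Paris', 'score': 0.99, 'entity': 'I-LOC'}]]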
    """

    def __call__(self, *texts, **kwargs):
        inputs, answers = self._args_parser(*texts, **kwargs), []
        for sentence in inputs:

            # Ugly token to word idx mapping (for now)
            token_to_word, words = [], sentence.split(' ')
            for i, w in enumerate(words):
                tokens = self.tokenizer.tokenize(w)
                token_to_word += [i] * len(tokens)
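            # e.g. (illustrative): if the word at index 2 is split into 3 sub-tokens, the loop
            # appends [2, 2, 2], so every sub-token points back to its originating word.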

            # Manage correct placement of the tensors
            with self.device_placement():
                tokens = self.tokenizer.encode_plus(
                    sentence, return_attention_mask=False,
                    return_tensors='tf' if is_tf_available() else 'pt'
                )

                # Forward
                if is_torch_available():
                    with torch.no_grad():
                        entities = self.model(**tokens)[0][0].cpu().numpy()
                else:
                    entities = self.model(tokens)[0][0].numpy()

            # Normalize scores
            answer, token_start = [], 1
            for idx, word in groupby(token_to_word):

                # Softmax over the labels for the first sub-token of the current word
                score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
                label_idx = score.argmax()

                answer += [{
                    'word': words[idx],
                    'score': score[label_idx].item(),
                    'entity': self.model.config.id2label[label_idx]
                }]

                # Update token start
                token_start += len(list(word))

            # Append
            answers += [answer]
        return answers


class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal SquadExample / SquadFeature structures.

    QuestionAnsweringArgumentHandler manages all the possible ways to create a SquadExample from the
    command-line supplied arguments.
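
    Supported call forms (illustrative):
        handler(X={'question': 'Where?', 'context': '...'})
        handler(data=[{'question': 'Where?', 'context': '...'}])
        handler(question='Where?', context='...')
        handler(question=['Where?', 'When?'], context=['...', '...'])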
    """
    def __call__(self, *args, **kwargs):
        # Positional args: handling is essentially the same as for X and data, so forward them to avoid duplication
        if args is not None and len(args) > 0:
            if len(args) == 1:
                kwargs['X'] = args[0]
            else:
                kwargs['X'] = list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if 'X' in kwargs or 'data' in kwargs:
            data = kwargs['X'] if 'X' in kwargs else kwargs['data']

            if not isinstance(data, list):
                data = [data]

            for i, item in enumerate(data):
                if isinstance(item, dict):
                    if any(k not in item for k in ['question', 'context']):
                        raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
                    data[i] = QuestionAnsweringPipeline.create_sample(**item)

                elif isinstance(item, SquadExample):
                    continue
                else:
                    raise ValueError(
                        '{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
                            .format('X' if 'X' in kwargs else 'data')
                    )
            inputs = data

        # Tabular input
        elif 'question' in kwargs and 'context' in kwargs:
            if isinstance(kwargs['question'], str):
                kwargs['question'] = [kwargs['question']]

            if isinstance(kwargs['context'], str):
                kwargs['context'] = [kwargs['context']]

            inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs['question'], kwargs['context'])]
        else:
            raise ValueError('Unknown arguments {}'.format(kwargs))

        if not isinstance(inputs, list):
            inputs = [inputs]

        return inputs


class QuestionAnsweringPipeline(Pipeline):
    """
    Question Answering pipeline using a ModelForQuestionAnswering head.
    """

    @staticmethod
    def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the SquadExample/SquadFeatures internally.
        This helper method encapsulates all the logic for converting question(s) and context(s) to SquadExample(s).
        We currently support extractive question answering.
        Args:
             question: (str, List[str]) The question(s) to ask of the associated context(s)
             context: (str, List[str]) The context(s) in which we will look for the answer.
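
        Example (illustrative):
            QuestionAnsweringPipeline.create_sample('Who?', 'Some short context.')  # -> a single SquadExample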
        """
        if isinstance(question, list):
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer], device: int = -1, **kwargs):
        super().__init__(model, tokenizer, args_parser=QuestionAnsweringArgumentHandler(),
                         device=device, **kwargs)

    def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
        """
        Generates the input dictionary with model-specific parameters.

        Returns:
            dict holding all the required parameters for model's forward
        """
        args = ['input_ids', 'attention_mask']
        model_type = type(self.model).__name__.lower()

        if 'distilbert' not in model_type and 'xlm' not in model_type:
            args += ['token_type_ids']

        if 'xlnet' in model_type or 'xlm' in model_type:
            args += ['cls_index', 'p_mask']

        if isinstance(features, SquadExample):
            return {k: features.__dict__[k] for k in args}
        else:
            return {k: [feature.__dict__[k] for feature in features] for k in args}

    def __call__(self, *texts, **kwargs):
        """
        Args:
            We support multiple use-cases, the following are exclusive:
            X: sequence of SquadExample
            data: sequence of SquadExample
            question: (str, List[str]), batch of question(s) to map along with context
            context: (str, List[str]), batch of context(s) associated with the provided question keyword argument
        Returns:
            dict: {'answer': str, 'score': float, 'start': int, 'end': int}
            answer: the textual answer in the initial context
            score: the model's confidence score for the answer
            start: the character index in the original string corresponding to the beginning of the answer's span
            end: the character index in the original string corresponding to the end of the answer's span
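
        Example (hypothetical checkpoint name):
            nlp = pipeline('question-answering', model='some-qa-checkpoint')
            nlp(question='Where is the answer?', context='The answer is extracted from the context.')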
        """
        # Set default values
        kwargs.setdefault('topk', 1)
        kwargs.setdefault('doc_stride', 128)
        kwargs.setdefault('max_answer_len', 15)
        kwargs.setdefault('max_seq_len', 384)
        kwargs.setdefault('max_question_len', 64)

        if kwargs['topk'] < 1:
            raise ValueError('topk parameter should be >= 1 (got {})'.format(kwargs['topk']))

        if kwargs['max_answer_len'] < 1:
            raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len']))

        # Convert inputs to features
        examples = self._args_parser(*texts, **kwargs)
        features = squad_convert_examples_to_features(
            examples, self.tokenizer, kwargs['max_seq_len'],
            kwargs['doc_stride'], kwargs['max_question_len'], False
        )
        fw_args = self.inputs_for_model(features)

        # Manage tensor allocation on correct device
        with self.device_placement():
            if is_tf_available():
                import tensorflow as tf
                fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                start, end = self.model(fw_args)
                start, end = start.numpy(), end.numpy()
            else:
                import torch
                with torch.no_grad():
                    # Convert the inputs to PyTorch tensors before forwarding through the model
                    fw_args = {k: torch.tensor(v) for (k, v) in fw_args.items()}
                    start, end = self.model(**fw_args)
                    start, end = start.cpu().numpy(), end.cpu().numpy()

        answers = []
        for (example, feature, start_, end_) in zip(examples, features, start, end):
            # Normalize logits and spans to retrieve the answer
            start_ = np.exp(start_) / np.sum(np.exp(start_))
            end_ = np.exp(end_) / np.sum(np.exp(end_))

            # Mask padding and question
            start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)

            # TODO : What happens if not possible
            # Mask CLS
            start_[0] = end_[0] = 0

            starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
            char_to_word = np.array(example.char_to_word_offset)

            # Convert the answer (tokens) back to the original text
            answers += [
                {
                    'score': score.item(),
                    'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                    'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
                    'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]:feature.token_to_orig_map[e] + 1])
                }
                for s, e, score in zip(starts, ends, scores)
            ]

        if len(answers) == 1:
            return answers[0]
        return answers

    def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
        """
        Takes the output of any QuestionAnswering head and generates probabilities for each span to be
        the actual answer.
        In addition, it filters out some unwanted/impossible cases, like an answer length greater than
        max_answer_len or an answer end position coming before the start position.
        The method supports outputting the k-best answers through the topk argument.

        Args:
            start: numpy array, holding individual start probabilities for each token
            end: numpy array, holding individual end probabilities for each token
            topk: int, indicates how many possible answer span(s) to extract from the model's output
            max_answer_len: int, maximum size of the answer to extract from the model's output
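
        Worked example (illustrative numbers): with start = [0.1, 0.6, 0.3] and end = [0.2, 0.3, 0.5],
        outer[i, j] = start[i] * end[j]; np.triu keeps spans with end >= start, np.tril(..., max_answer_len - 1)
        drops spans longer than max_answer_len, and the highest remaining cell (i=1, j=2) gives start=1, end=2.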
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidate with end < start and end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) < topk:
            idx_sort = np.argsort(-scores_flat)
        else:
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        return start, end, candidates[0, start, end]

    def span_to_answer(self, text: str, start: int, end: int):
        """
        When decoding from token probabilities, this method maps token indexes back to actual words in
        the initial context.

        Args:
            text: str, the actual context to extract the answer from
            start: int, starting answer token index
            end: int, ending answer token index

        Returns:
            dict: {'answer': str, 'start': int, 'end': int}
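
        Example (illustrative, assuming every word maps to exactly one token):
            span_to_answer('Pipelines are great', start=1, end=2)
            # -> {'answer': 'are great', 'start': 10, 'end': 19}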
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for word in text.split(" "):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {'answer': ' '.join(words), 'start': max(0, char_start_idx), 'end': min(len(text), char_end_idx)}


# Register all the supported tasks here
SUPPORTED_TASKS = {
    'feature-extraction': {
        'impl': FeatureExtractionPipeline,
        'tf': TFAutoModel if is_tf_available() else None,
        'pt': AutoModel if is_torch_available() else None,
    },
    'text-classification': {
        'impl': TextClassificationPipeline,
        'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
        'pt': AutoModelForSequenceClassification if is_torch_available() else None
    },
    'ner': {
        'impl': NerPipeline,
        'tf': TFAutoModelForTokenClassification if is_tf_available() else None,
        'pt': AutoModelForTokenClassification if is_torch_available() else None,
    },
    'question-answering': {
        'impl': QuestionAnsweringPipeline,
        'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
        'pt': AutoModelForQuestionAnswering if is_torch_available() else None
    }
}
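
# Each task entry maps to its pipeline implementation and a framework-specific allocator, e.g. (illustrative):
#   SUPPORTED_TASKS['ner']['impl']  # -> NerPipeline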


def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] = None,
             tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
    """
    Utility factory method to build a pipeline.
    Pipelines are made of:
        A Tokenizer instance in charge of mapping raw textual input to tokens
        A Model instance
        Some (optional) post-processing for enhancing the model's output
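
    Example (hypothetical checkpoint name):
        nlp = pipeline('text-classification', model='some-sst2-checkpoint')
        nlp('Transformers pipelines are straightforward to use')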
    """
    # Try to infer tokenizer from model name (if provided as str)
    if tokenizer is None:
        if not isinstance(model, str):
            # Impossible to guess which tokenizer to use here
            raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
        else:
            tokenizer = model

    tokenizer = tokenizer if isinstance(tokenizer, PreTrainedTokenizer) else AutoTokenizer.from_pretrained(tokenizer)

    if task not in SUPPORTED_TASKS:
        raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

    targeted_task = SUPPORTED_TASKS[task]
    task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']

    # Special handling for model conversion
    if isinstance(model, str):
        from_tf = model.endswith('.h5') and not is_tf_available()
        from_pt = model.endswith('.bin') and not is_torch_available()

        if from_tf:
            logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
                           'Trying to load the model with PyTorch.')
        elif from_pt:
            logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
                           'Trying to load the model with TensorFlow.')
    else:
        from_tf = from_pt = False

    if isinstance(config, str):
        config = PretrainedConfig.from_pretrained(config)

    if allocator.__name__.startswith('TF'):
        model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
    else:
        model = allocator.from_pretrained(model, config=config, from_tf=from_tf)
    return task(model, tokenizer, **kwargs)