import json
import logging
import os
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
from tqdm import tqdm

from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize
from .utils import DataProcessor


if is_torch_available():
    import torch
    from torch.utils.data import TensorDataset

if is_tf_available():
    import tensorflow as tf

logger = logging.getLogger(__name__)


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def _new_check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    # if len(doc_spans) == 1:
    # return True
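    # Dict-based variant of ``_check_is_max_context``: here each doc span is a
    # dict with "start" and "length" keys (the encoded spans built in
    # ``squad_convert_example_to_features``) rather than an object with
    # ``.start``/``.length`` attributes; the scoring logic is identical.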
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span["start"] + doc_span["length"] - 1
        if position < doc_span["start"]:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span["start"]
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def _is_whitespace(c):
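    # 0x202F is the Unicode NARROW NO-BREAK SPACE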
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False


def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
    features = []
    if is_training and not example.is_impossible:
        # Get start and end position
        start_position = example.start_position
        end_position = example.end_position

        # If the answer cannot be found in the text, then skip this example.
        actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
        cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
        if actual_text.find(cleaned_answer_text) == -1:
            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
            return []

    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    for (i, token) in enumerate(example.doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = tokenizer.tokenize(token)
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)

    if is_training and not example.is_impossible:
        tok_start_position = orig_to_tok_index[example.start_position]
        if example.end_position < len(example.doc_tokens) - 1:
            tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
        else:
            tok_end_position = len(all_doc_tokens) - 1

        (tok_start_position, tok_end_position) = _improve_answer_span(
            all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
        )

    spans = []

    truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
    sequence_added_tokens = (
        tokenizer.max_len - tokenizer.max_len_single_sentence + 1
        if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
        else tokenizer.max_len - tokenizer.max_len_single_sentence
    )
    sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

    span_doc_tokens = all_doc_tokens
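    # Slide a window over the document: each iteration encodes the query plus the
    # current window, and the ``stride`` passed to ``encode_plus`` below is sized
    # so that consecutive windows start ``doc_stride`` tokens apart once the query
    # and the special tokens are accounted for.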
    while len(spans) * doc_stride < len(all_doc_tokens):

        encoded_dict = tokenizer.encode_plus(
            truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
            span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
            max_length=max_seq_length,
            return_overflowing_tokens=True,
            pad_to_max_length=True,
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
            return_token_type_ids=True,
        )

        paragraph_len = min(
            len(all_doc_tokens) - len(spans) * doc_stride,
            max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
        )

        if tokenizer.pad_token_id in encoded_dict["input_ids"]:
            if tokenizer.padding_side == "right":
                non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
            else:
                last_padding_id_position = (
                    len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
                )
                non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]

        else:
            non_padded_ids = encoded_dict["input_ids"]

        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

        token_to_orig_map = {}
        for i in range(paragraph_len):
            index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
            token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]

        encoded_dict["paragraph_len"] = paragraph_len
        encoded_dict["tokens"] = tokens
        encoded_dict["token_to_orig_map"] = token_to_orig_map
        encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
        encoded_dict["token_is_max_context"] = {}
        encoded_dict["start"] = len(spans) * doc_stride
        encoded_dict["length"] = paragraph_len

        spans.append(encoded_dict)

        if "overflowing_tokens" not in encoded_dict:
            break
        span_doc_tokens = encoded_dict["overflowing_tokens"]

    for doc_span_index in range(len(spans)):
        for j in range(spans[doc_span_index]["paragraph_len"]):
            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
            index = (
                j
                if tokenizer.padding_side == "left"
                else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
            )
            spans[doc_span_index]["token_is_max_context"][index] = is_max_context

    for span in spans:
        # Identify the position of the CLS token
        cls_index = span["input_ids"].index(tokenizer.cls_token_id)

        # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in an answer)
        # The original TF implementation also keeps the classification token (set to 0) (the reason is unclear)
        p_mask = np.array(span["token_type_ids"])

        # Cap segment ids at 1: some tokenizers (e.g. XLNet) use segment ids greater than 1
        p_mask = np.minimum(p_mask, 1)

        if tokenizer.padding_side == "right":
            # When padding on the right, token_type_ids are 0 for the query and 1 for the context,
            # so invert the mask: 1 then marks tokens that cannot be part of the answer
            p_mask = 1 - p_mask

        p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1

        # Set the CLS index to '0'
        p_mask[cls_index] = 0

        span_is_impossible = example.is_impossible
        start_position = 0
        end_position = 0
        if is_training and not span_is_impossible:
            # For training, if our document chunk does not contain an annotation
            # we throw it out, since there is nothing to predict.
            doc_start = span["start"]
            doc_end = span["start"] + span["length"] - 1
            out_of_span = False

            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                out_of_span = True

            if out_of_span:
                start_position = cls_index
                end_position = cls_index
                span_is_impossible = True
            else:
                if tokenizer.padding_side == "left":
                    doc_offset = 0
                else:
                    doc_offset = len(truncated_query) + sequence_added_tokens

                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

        features.append(
            SquadFeatures(
                span["input_ids"],
                span["attention_mask"],
                span["token_type_ids"],
                cls_index,
                p_mask.tolist(),
                example_index=0,  # Cannot set unique_id and example_index here; they are set after the multiprocessing step.
                unique_id=0,
                paragraph_len=span["paragraph_len"],
                token_is_max_context=span["token_is_max_context"],
                tokens=span["tokens"],
                token_to_orig_map=span["token_to_orig_map"],
                start_position=start_position,
                end_position=end_position,
                is_impossible=span_is_impossible,
            )
        )
    return features


def squad_convert_example_to_features_init(tokenizer_for_convert):
    # Pool initializer: expose the tokenizer as a module-level global so that
    # worker processes can use it inside squad_convert_example_to_features
    global tokenizer
    tokenizer = tokenizer_for_convert


def squad_convert_examples_to_features(
    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model.
    It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.

    Args:
        examples: list of :class:`~transformers.data.processors.squad.SquadExample`
        tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer`
        max_seq_length: The maximum sequence length of the inputs.
        doc_stride: The stride used when the context is too large and is split across several features.
        max_query_length: The maximum length of the query.
        is_training: whether to create features for model training (True) or model evaluation (False).
        return_dataset: Default False. Either 'pt' or 'tf'.
            if 'pt': returns a torch.utils.data.TensorDataset,
            if 'tf': returns a tf.data.Dataset
        threads: multiple processing threads.


    Returns:
        list of :class:`~transformers.data.processors.squad.SquadFeatures`

    Example::

        processor = SquadV2Processor()
        examples = processor.get_dev_examples(data_dir)

        features = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
        )
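
        # A hedged sketch: additionally build a PyTorch TensorDataset
        # (assumes torch is installed; use return_dataset="tf" for tf.data)
        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
        )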
    """

    # Convert the examples to features, parallelizing over a process pool
    features = []
    threads = min(threads, cpu_count())
    with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
        annotate_ = partial(
            squad_convert_example_to_features,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=is_training,
        )
        features = list(
            tqdm(
                p.imap(annotate_, examples, chunksize=32),
                total=len(examples),
                desc="convert squad examples to features",
            )
        )
    new_features = []
    unique_id = 1000000000
    example_index = 0
    for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
        if not example_features:
            continue
        for example_feature in example_features:
            example_feature.example_index = example_index
            example_feature.unique_id = unique_id
            new_features.append(example_feature)
            unique_id += 1
        example_index += 1
    features = new_features
    del new_features
    if return_dataset == "pt":
        if not is_torch_available():
            raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")

        # Convert to Tensors and build dataset
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
        all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
        all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)

        if not is_training:
            all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
            )
        else:
            all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
            all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids,
                all_attention_masks,
                all_token_type_ids,
                all_start_positions,
                all_end_positions,
                all_cls_index,
                all_p_mask,
                all_is_impossible,
            )

        return features, dataset
    elif return_dataset == "tf":
        if not is_tf_available():
            raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")

        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    {
                        "start_position": ex.start_position,
                        "end_position": ex.end_position,
                        "cls_index": ex.cls_index,
                        "p_mask": ex.p_mask,
                        "is_impossible": ex.is_impossible,
                    },
                )

        return tf.data.Dataset.from_generator(
            gen,
            (
                {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
                {
                    "start_position": tf.int64,
                    "end_position": tf.int64,
                    "cls_index": tf.int64,
                    "p_mask": tf.int32,
                    "is_impossible": tf.int32,
                },
            ),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                {
                    "start_position": tf.TensorShape([]),
                    "end_position": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                    "is_impossible": tf.TensorShape([]),
                },
            ),
        )

    return features


class SquadProcessor(DataProcessor):
    """
    Processor for the SQuAD data set.
    Overridden by SquadV1Processor and SquadV2Processor, used by version 1.1 and version 2.0 of SQuAD, respectively.
    """

    train_file = None
    dev_file = None

    def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
        if not evaluate:
            answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
            answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
            answers = []
        else:
            answers = [
                {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
                for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
            ]

            answer = None
            answer_start = None

        return SquadExample(
            qas_id=tensor_dict["id"].numpy().decode("utf-8"),
            question_text=tensor_dict["question"].numpy().decode("utf-8"),
            context_text=tensor_dict["context"].numpy().decode("utf-8"),
            answer_text=answer,
            start_position_character=answer_start,
            title=tensor_dict["title"].numpy().decode("utf-8"),
            answers=answers,
        )

    def get_examples_from_dataset(self, dataset, evaluate=False):
        """
        Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset.

        Args:
            dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")`
            evaluate: boolean specifying if in evaluation mode or in training mode

        Returns:
            List of SquadExample

        Examples::

            import tensorflow_datasets as tfds
            dataset = tfds.load("squad")

            training_examples = get_examples_from_dataset(dataset, evaluate=False)
            evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
        """

        if evaluate:
            dataset = dataset["validation"]
        else:
            dataset = dataset["train"]

        examples = []
        for tensor_dict in tqdm(dataset):
            examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))

        return examples

    def get_train_examples(self, data_dir, filename=None):
        """
        Returns the training examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the training file has a different name than the original one
                which is `train-v1.1.json` and `train-v2.0.json` for SQuAD versions 1.1 and 2.0 respectively.

        """
        if data_dir is None:
            data_dir = ""

        if self.train_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "train")

    def get_dev_examples(self, data_dir, filename=None):
        """
        Returns the evaluation examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the evaluation file has a different name than the original one
                which is `dev-v1.1.json` and `dev-v2.0.json` for SQuAD versions 1.1 and 2.0 respectively.
        """
        if data_dir is None:
            data_dir = ""

        if self.dev_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "dev")

    def _create_examples(self, input_data, set_type):
        is_training = set_type == "train"
        examples = []
        for entry in tqdm(input_data):
            title = entry["title"]
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]
                    start_position_character = None
                    answer_text = None
                    answers = []

                    if "is_impossible" in qa:
                        is_impossible = qa["is_impossible"]
                    else:
                        is_impossible = False

                    if not is_impossible:
                        if is_training:
                            answer = qa["answers"][0]
                            answer_text = answer["text"]
                            start_position_character = answer["answer_start"]
                        else:
                            answers = qa["answers"]

                    example = SquadExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position_character=start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )

                    examples.append(example)
        return examples


class SquadV1Processor(SquadProcessor):
    train_file = "train-v1.1.json"
    dev_file = "dev-v1.1.json"


class SquadV2Processor(SquadProcessor):
    train_file = "train-v2.0.json"
    dev_file = "dev-v2.0.json"


class SquadExample(object):
    """
    A single training/test example for the Squad dataset, as loaded from disk.

    Args:
        qas_id: The example's unique identifier
        question_text: The question string
        context_text: The context string
        answer_text: The answer string
        start_position_character: The character position of the start of the answer
        title: The title of the example
        answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
        is_impossible: False by default, set to True if the example has no possible answer.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        context_text,
        answer_text,
        start_position_character,
        title,
        answers=[],
        is_impossible=False,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.context_text = context_text
        self.answer_text = answer_text
        self.title = title
        self.is_impossible = is_impossible
        self.answers = answers

        self.start_position, self.end_position = 0, 0

        doc_tokens = []
        char_to_word_offset = []
        prev_is_whitespace = True

        # Split on whitespace so that different tokens may be attributed to their original position.
        for c in self.context_text:
            if _is_whitespace(c):
                prev_is_whitespace = True
            else:
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False
            char_to_word_offset.append(len(doc_tokens) - 1)

        self.doc_tokens = doc_tokens
        self.char_to_word_offset = char_to_word_offset
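
        # A hedged illustration: the context "New York" yields
        # doc_tokens ["New", "York"] and char_to_word_offset
        # [0, 0, 0, 0, 1, 1, 1, 1] (the space maps to the preceding token),
        # so any answer character position can be projected onto a word index.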

        # Start and end positions are only set when an answer is provided (i.e. for answerable training examples).
        if start_position_character is not None and not is_impossible:
            self.start_position = char_to_word_offset[start_position_character]
            self.end_position = char_to_word_offset[
                min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
            ]


class SquadFeatures(object):
    """
    Single squad example features to be fed to a model.
    Those features are model-specific and can be crafted from :class:`~transformers.data.processors.squad.SquadExample`
    using the :func:`~transformers.data.processors.squad.squad_convert_examples_to_features` method.

    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        cls_index: the index of the CLS token.
        p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
            Mask with 1 for tokens that cannot be in the answer and 0 for tokens that can be in an answer.
        example_index: the index of the example
        unique_id: The unique Feature identifier
        paragraph_len: The length of the context
        token_is_max_context: Dictionary of booleans identifying which tokens have their maximum context in this feature object.
            If a token does not have its maximum context in this feature object, it means that another feature object
            has more information related to that token and should be prioritized over this feature for that token.
        tokens: list of tokens corresponding to the input ids
        token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
        start_position: index of the first token of the answer
        end_position: index of the last token of the answer
    """

    def __init__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        cls_index,
        p_mask,
        example_index,
        unique_id,
        paragraph_len,
        token_is_max_context,
        tokens,
        token_to_orig_map,
        start_position,
        end_position,
        is_impossible,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.cls_index = cls_index
        self.p_mask = p_mask

        self.example_index = example_index
        self.unique_id = unique_id
        self.paragraph_len = paragraph_len
        self.token_is_max_context = token_is_max_context
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map

        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible


class SquadResult(object):
    """
    Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.

    Args:
        unique_id: The unique identifier corresponding to that example.
        start_logits: The logits corresponding to the start of the answer
        end_logits: The logits corresponding to the end of the answer
    """

    def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
        self.start_logits = start_logits
        self.end_logits = end_logits
        self.unique_id = unique_id

        if start_top_index:
            self.start_top_index = start_top_index
            self.end_top_index = end_top_index
            self.cls_logits = cls_logits
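

# A hedged usage sketch (``feature``, ``outputs`` and ``i`` are hypothetical
# names): wrap one feature's start/end logits for evaluation.
#
#     result = SquadResult(
#         unique_id=feature.unique_id,
#         start_logits=outputs[0][i].tolist(),
#         end_logits=outputs[1][i].tolist(),
#     )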