Commit bd41e829 authored by Lysandre

Cleanup & Evaluation now works

parent 0669c1fc
@@ -16,7 +16,7 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
-from transformers.data.processors.squad import SquadV1Processor
+from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
 import argparse
 import logging
@@ -45,9 +45,9 @@ from transformers import (WEIGHTS_NAME, BertConfig,
                           XLNetTokenizer,
                           DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
-from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
+from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
-from utils_squad import (RawResult, write_predictions,
+from utils_squad import (convert_examples_to_features as old_convert, read_squad_examples as old_read, RawResult, write_predictions,
                          RawResultExtended, write_predictions_extended)
 # The follwing import is the official SQuAD evaluation script (2.0).
@@ -304,28 +304,20 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         features = torch.load(cached_features_file)
     else:
         logger.info("Creating features from dataset file at %s", input_file)
-        examples = read_squad_examples(input_file=input_file,
-                                       is_training=not evaluate,
-                                       version_2_with_negative=args.version_2_with_negative)
-        keep_n_examples = 1000
-        processor = SquadV1Processor()
-        values = processor.get_dev_examples("examples/squad")
-        examples = values[:keep_n_examples]
-        features = squad_convert_examples_to_features(examples=exampless,
-                                                      tokenizer=tokenizer,
-                                                      max_seq_length=args.max_seq_length,
-                                                      doc_stride=args.doc_stride,
-                                                      max_query_length=args.max_query_length,
-                                                      is_training=not evaluate,
-                                                      cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-                                                      pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
-                                                      cls_token_at_end=True if args.model_type in ['xlnet'] else False,
-                                                      sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-        print("DONE")
-        import sys
-        sys.exit()
+        processor = SquadV2Processor()
+        examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
+        features = squad_convert_examples_to_features(
+            examples=examples,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            doc_stride=args.doc_stride,
+            max_query_length=args.max_query_length,
+            is_training=not evaluate,
+            sequence_a_is_doc=True if args.model_type in ['xlnet'] else False
+        )
     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)
@@ -335,8 +327,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     # Convert to Tensors and build dataset
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
     all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
     all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
     if evaluate:
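For readers following the run_squad.py change above: the old utils_squad-based conversion (plus the temporary keep_n_examples / sys.exit() debugging code) is replaced by the library's SquadV2Processor and squad_convert_examples_to_features. Below is a minimal sketch of that data path outside the script; it assumes the transformers version at this commit exposes these names as imported in the diff, that SQuAD v2 JSON files live under examples/squad (the path hard-coded in the diff), and uses illustrative hyperparameter values in place of args.

import torch
from torch.utils.data import TensorDataset

from transformers import BertTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV2Processor

# Build SquadExample objects from the SQuAD v2 dev JSON (the evaluation branch).
processor = SquadV2Processor()
examples = processor.get_dev_examples("examples/squad")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Same keyword arguments as the new call in load_and_cache_examples.
features = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,      # illustrative values, normally taken from args
    doc_stride=128,
    max_query_length=64,
    is_training=False,
)

# The new feature objects expose attention_mask / token_type_ids, which is why
# the tensor-building hunk above switches away from input_mask / segment_ids.
dataset = TensorDataset(
    torch.tensor([f.input_ids for f in features], dtype=torch.long),
    torch.tensor([f.attention_mask for f in features], dtype=torch.long),
    torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
)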
@@ -74,26 +74,16 @@ def _is_whitespace(c):
 def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                        doc_stride, max_query_length, is_training,
-                                       cls_token_at_end=True,
-                                       cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
-                                       sequence_a_segment_id=0, sequence_b_segment_id=1,
-                                       cls_token_segment_id=0, pad_token_segment_id=0,
-                                       mask_padding_with_zero=True,
                                        sequence_a_is_doc=False):
     """Loads a data file into a list of `InputBatch`s."""
+    cls_token = tokenizer.cls_token
+    sep_token = tokenizer.sep_token
     # Defining helper methods
     unique_id = 1000000000
     features = []
-    new_features = []
     for (example_index, example) in enumerate(tqdm(examples)):
         if is_training and not example.is_impossible:
             # Get start and end position
-            answer_length = len(example.answer_text)
             start_position = example.start_position
             end_position = example.end_position
@@ -227,7 +217,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 end_position = tok_end_position - doc_start + doc_offset
-            new_features.append(NewSquadFeatures(
+            features.append(NewSquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
                 span['token_type_ids'],
@@ -247,7 +237,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             unique_id += 1
-    return new_features
+    return features
 class SquadProcessor(DataProcessor):
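The cleanup in the second file removes the special-token and segment-id keyword arguments from squad_convert_examples_to_features: the tokenizer already knows its own cls_token and sep_token, so per-model plumbing such as cls_token_segment_id=2 for XLNet no longer has to be passed in from the example script. A small sketch of that pattern (not part of the commit), assuming standard pretrained tokenizers are available:

from transformers import BertTokenizer, XLNetTokenizer

bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")
xlnet_tok = XLNetTokenizer.from_pretrained("xlnet-base-cased")

# Each tokenizer carries its own special tokens, so they no longer need to be
# threaded through as cls_token='[CLS]' / sep_token='[SEP]' keyword arguments.
print(bert_tok.cls_token, bert_tok.sep_token)    # [CLS] [SEP]
print(xlnet_tok.cls_token, xlnet_tok.sep_token)  # <cls> <sep>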