Commit de276de1 authored by LysandreJik

Working evaluation

parent c835bc85
examples/run_squad.py
@@ -16,7 +16,8 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
-from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
+from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
+from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate
 import argparse
 import logging
@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
         model.eval()
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
-            inputs = {'input_ids': batch[0],
-                      'attention_mask': batch[1]
-                     }
+            inputs = {
+                'input_ids': batch[0],
+                'attention_mask': batch[1]
+            }
             if args.model_type != 'distilbert':
                 inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]
@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
             unique_id = int(eval_feature.unique_id)
-            if args.model_type in ['xlnet', 'xlm']:
-                # XLNet uses a more complex post-processing procedure
-                result = RawResultExtended(unique_id = unique_id,
-                                           start_top_log_probs = to_list(outputs[0][i]),
-                                           start_top_index = to_list(outputs[1][i]),
-                                           end_top_log_probs = to_list(outputs[2][i]),
-                                           end_top_index = to_list(outputs[3][i]),
-                                           cls_logits = to_list(outputs[4][i]))
-            else:
-                result = RawResult(unique_id = unique_id,
-                                   start_logits = to_list(outputs[0][i]),
-                                   end_logits = to_list(outputs[1][i]))
+            result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
             all_results.append(result)
 
     evalTime = timeit.default_timer() - start_time
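Editor's note: the single SquadResult(...) call above replaces the model-specific RawResult / RawResultExtended branch by packing every per-example output plus the unique id into one argument. Below is a minimal sketch of a container compatible with that call site; the class actually introduced by this commit may define a different signature.

    class SquadResult(object):
        # Hypothetical sketch only, mirroring the call
        # SquadResult([to_list(output[i]) for output in outputs] + [unique_id]).
        def __init__(self, fields):
            *model_outputs, unique_id = fields
            self.unique_id = unique_id
            if len(model_outputs) == 2:
                # Standard heads: start and end logits.
                self.start_logits, self.end_logits = model_outputs
            else:
                # Extended XLNet/XLM heads, in the order the model returns them.
                (self.start_top_log_probs, self.start_top_index,
                 self.end_top_log_probs, self.end_top_index,
                 self.cls_logits) = model_outputs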
@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
     if args.model_type in ['xlnet', 'xlm']:
         # XLNet uses a more complex post-processing procedure
-        write_predictions_extended(examples, features, all_results, args.n_best_size,
-                                   args.max_answer_length, output_prediction_file,
-                                   output_nbest_file, output_null_log_odds_file, args.predict_file,
-                                   model.config.start_n_top, model.config.end_n_top,
-                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
+        predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
+                                                   args.max_answer_length, output_prediction_file,
+                                                   output_nbest_file, output_null_log_odds_file, args.predict_file,
+                                                   model.config.start_n_top, model.config.end_n_top,
+                                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
-        write_predictions(examples, features, all_results, args.n_best_size,
-                          args.max_answer_length, args.do_lower_case, output_prediction_file,
-                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                          args.version_2_with_negative, args.null_score_diff_threshold)
-
-    # Evaluate with the official SQuAD script
-    evaluate_options = EVAL_OPTS(data_file=args.predict_file,
-                                 pred_file=output_prediction_file,
-                                 na_prob_file=output_null_log_odds_file)
-    results = evaluate_on_squad(evaluate_options)
+        predictions = compute_predictions(examples, features, all_results, args.n_best_size,
+                                          args.max_answer_length, args.do_lower_case, output_prediction_file,
+                                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
+                                          args.version_2_with_negative, args.null_score_diff_threshold)
+
+    results = squad_evaluate(examples, predictions)
     return results
 
 
 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
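Editor's note: the detour through the official evaluation script (EVAL_OPTS plus evaluate_on_squad, which read the prediction files back from disk) becomes a direct in-process squad_evaluate(examples, predictions) call. Assuming predictions maps each qas_id to the predicted answer string, such an evaluation conceptually scores predictions against the gold answers now carried on each example; a rough illustration of the exact-match half, not the library's implementation:

    def sketch_exact_match(examples, predictions):
        # Rough illustration only: the real SQuAD metric also normalizes
        # text (casing, articles, punctuation) and handles unanswerable
        # questions via the null-odds threshold.
        correct = 0
        for example in examples:
            gold_texts = [answer["text"] for answer in (example.answers or [])]
            if predictions.get(example.qas_id) in gold_texts:
                correct += 1
        return 100.0 * correct / len(examples)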
@@ -306,8 +295,12 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         logger.info("Creating features from dataset file at %s", input_file)
         processor = SquadV2Processor()
-        examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
+        examples = processor.get_dev_examples("examples/squad", only_first=100) if evaluate else processor.get_train_examples("examples/squad")
+        # import tensorflow_datasets as tfds
+        # tfds_examples = tfds.load("squad")
+        # examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
+
         features = squad_convert_examples_to_features(
             examples=examples,
             tokenizer=tokenizer,
             max_seq_length=args.max_seq_length,
...
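Editor's note: the commented-out lines above sketch an alternative data path through tensorflow_datasets. Expanded slightly into runnable form (get_examples_from_dataset is taken verbatim from the comment; the split choice and use of the surrounding function's evaluate flag are assumptions):

    import tensorflow_datasets as tfds

    # Load SQuAD from TFDS instead of the local JSON files.
    tfds_examples = tfds.load("squad")

    # Convert one split into SquadExample objects via the processor method
    # referenced in the commented-out code.
    split = "validation" if evaluate else "train"
    examples = SquadV1Processor().get_examples_from_dataset(tfds_examples[split])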
[A further diff in this commit is collapsed and not shown.]
transformers/data/processors/squad.py
@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
                     else:
                         is_impossible = False
 
-                    if not is_impossible and is_training:
-                        if (len(qa["answers"]) != 1):
-                            raise ValueError(
-                                "For training, each question should have exactly 1 answer.")
-                        answer = qa["answers"][0]
-                        answer_text = answer['text']
-                        start_position_character = answer['answer_start']
+                    if not is_impossible:
+                        if is_training:
+                            answer = qa["answers"][0]
+                            answer_text = answer['text']
+                            start_position_character = answer['answer_start']
+                        else:
+                            answers = qa["answers"]
 
                     example = SquadExample(
                         qas_id=qas_id,
@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
                         answer_text=answer_text,
                         start_position_character=start_position_character,
                         title=title,
-                        is_impossible=is_impossible
+                        is_impossible=is_impossible,
+                        answers=answers
                     )
                     examples.append(example)
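Editor's note: in the hunks above, answers is only bound on the evaluation (not is_training) path, yet it is passed unconditionally to SquadExample; presumably it is initialized before the branch, outside the visible context, along the lines of:

    # Hypothetical initialization preceding the is_impossible branch, so the
    # SquadExample(..., answers=answers) call is defined on every path:
    answers = []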
@@ -352,6 +353,7 @@ class SquadExample(object):
                  answer_text,
                  start_position_character,
                  title,
+                 answers=None,
                  is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text
@@ -359,6 +361,7 @@ class SquadExample(object):
         self.answer_text = answer_text
         self.title = title
         self.is_impossible = is_impossible
+        self.answers = answers
         self.start_position, self.end_position = 0, 0
...
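Editor's note: with the new answers field, an evaluation-time example can carry every gold answer while training keeps a single span. A construction sketch under the signature shown above (the parameters before answer_text are elided in the hunk, so question_text and context_text are assumptions, and the sample values are illustrative only):

    example = SquadExample(
        qas_id="56be4db0acb8001400a502ec",
        question_text="Which NFL team represented the AFC at Super Bowl 50?",
        context_text="Super Bowl 50 was an American football game ...",
        answer_text=None,                     # no single gold span at eval time
        start_position_character=None,
        title="Super_Bowl_50",
        answers=[{"text": "Denver Broncos", "answer_start": 177}],
        is_impossible=False,
    )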