"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "a32ab5afbd286ef3640b5324ec7056304838804b"
Commit de276de1 authored by LysandreJik's avatar LysandreJik
Browse files

Working evaluation

parent c835bc85
......@@ -16,7 +16,8 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate
import argparse
import logging
......@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1]
}
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1]
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
example_indices = batch[3]
......@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
for i, example_index in enumerate(example_indices):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
result = RawResultExtended(unique_id = unique_id,
start_top_log_probs = to_list(outputs[0][i]),
start_top_index = to_list(outputs[1][i]),
end_top_log_probs = to_list(outputs[2][i]),
end_top_index = to_list(outputs[3][i]),
cls_logits = to_list(outputs[4][i]))
else:
result = RawResult(unique_id = unique_id,
start_logits = to_list(outputs[0][i]),
end_logits = to_list(outputs[1][i]))
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
all_results.append(result)
evalTime = timeit.default_timer() - start_time
......@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
write_predictions_extended(examples, features, all_results, args.n_best_size,
predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.predict_file,
model.config.start_n_top, model.config.end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging)
else:
write_predictions(examples, features, all_results, args.n_best_size,
predictions = compute_predictions(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
# Evaluate with the official SQuAD script
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
pred_file=output_prediction_file,
na_prob_file=output_null_log_odds_file)
results = evaluate_on_squad(evaluate_options)
results = squad_evaluate(examples, predictions)
return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
......@@ -306,8 +295,12 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
logger.info("Creating features from dataset file at %s", input_file)
processor = SquadV2Processor()
examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
features = squad_convert_examples_to_features(
examples = processor.get_dev_examples("examples/squad", only_first=100) if evaluate else processor.get_train_examples("examples/squad")
# import tensorflow_datasets as tfds
# tfds_examples = tfds.load("squad")
# examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
features = squad_convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
......
This diff is collapsed.
......@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
else:
is_impossible = False
if not is_impossible and is_training:
if (len(qa["answers"]) != 1):
raise ValueError(
"For training, each question should have exactly 1 answer.")
answer = qa["answers"][0]
answer_text = answer['text']
start_position_character = answer['answer_start']
if not is_impossible:
if is_training:
answer = qa["answers"][0]
answer_text = answer['text']
start_position_character = answer['answer_start']
else:
answers = qa["answers"]
example = SquadExample(
qas_id=qas_id,
......@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
answer_text=answer_text,
start_position_character=start_position_character,
title=title,
is_impossible=is_impossible
is_impossible=is_impossible,
answers=answers
)
examples.append(example)
......@@ -352,6 +353,7 @@ class SquadExample(object):
answer_text,
start_position_character,
title,
answers=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
......@@ -359,6 +361,7 @@ class SquadExample(object):
self.answer_text = answer_text
self.title = title
self.is_impossible = is_impossible
self.answers = answers
self.start_position, self.end_position = 0, 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment