Commit 164c794e authored by Lysandre, committed by Lysandre Debut

New SQuAD API for distillation script

parent 801f2ac8
@@ -15,7 +15,6 @@
# limitations under the License.
""" This is the exact same script as `examples/run_squad.py` (as of 2019, October 4th) with an additional and optional step of distillation."""
import argparse
import glob
import logging
@@ -26,7 +25,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -46,22 +45,14 @@ from transformers import (
XLNetForQuestionAnswering,
XLNetTokenizer,
get_linear_schedule_with_warmup,
squad_convert_examples_to_features,
)
from ..utils_squad import (
RawResult,
RawResultExtended,
convert_examples_to_features,
read_squad_examples,
write_predictions,
write_predictions_extended,
)
from transformers.data.metrics.squad_metrics import (
compute_predictions_log_probs,
compute_predictions_logits,
squad_evaluate,
)
# The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
# We've added it here for automated tests (see examples/test_examples.py file)
from ..utils_squad_evaluate import EVAL_OPTS
from ..utils_squad_evaluate import main as evaluate_on_squad
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
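In short, the example-local SQuAD helpers are replaced by their library equivalents. A rough mapping (old names from `examples/utils_squad.py` and `examples/utils_squad_evaluate.py` on the left; the right-hand names are the new imports above):

```python
# read_squad_examples             -> SquadV1Processor / SquadV2Processor
# convert_examples_to_features    -> squad_convert_examples_to_features
# RawResult / RawResultExtended   -> SquadResult
# write_predictions               -> compute_predictions_logits
# write_predictions_extended      -> compute_predictions_log_probs
# EVAL_OPTS + evaluate_on_squad   -> squad_evaluate
```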
try:
@@ -69,7 +60,6 @@ try:
except ImportError:
from tensorboardX import SummaryWriter
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
@@ -294,20 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""):
for i, example_index in enumerate(example_indices):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
if args.model_type in ["xlnet", "xlm"]:
# XLNet uses a more complex post-processing procedure
result = RawResultExtended(
unique_id=unique_id,
start_top_log_probs=to_list(outputs[0][i]),
start_top_index=to_list(outputs[1][i]),
end_top_log_probs=to_list(outputs[2][i]),
end_top_index=to_list(outputs[3][i]),
cls_logits=to_list(outputs[4][i]),
)
output = [to_list(output[i]) for output in outputs]
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
# models only use two.
if len(output) >= 5:
start_logits = output[0]
start_top_index = output[1]
end_logits = output[2]
end_top_index = output[3]
cls_logits = output[4]
result = SquadResult(
unique_id,
start_logits,
end_logits,
start_top_index=start_top_index,
end_top_index=end_top_index,
cls_logits=cls_logits,
)
else:
result = RawResult(
unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
)
start_logits, end_logits = output
result = SquadResult(unique_id, start_logits, end_logits)
all_results.append(result)
# Compute predictions
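The two old result classes collapse into the single `SquadResult`, whose beam-search fields are optional keyword arguments. A minimal sketch with hypothetical ids and logits (the numeric values below are made up for illustration):

```python
from transformers.data.processors.squad import SquadResult

# Hypothetical per-token scores, standing in for to_list(outputs[k][i]).
start_logits = [0.1, 2.3, 0.4]
end_logits = [0.2, 0.1, 1.9]

# BERT-style heads: only start/end logits are needed.
simple = SquadResult(unique_id=1000, start_logits=start_logits, end_logits=end_logits)

# XLNet/XLM-style heads: same class, with the three extra fields filled in
# (these carried the *_top_log_probs / *_top_index values in RawResultExtended).
extended = SquadResult(
    unique_id=1001,
    start_logits=start_logits,
    end_logits=end_logits,
    start_top_index=[1, 0, 2],
    end_top_index=[2, 1, 0],
    cls_logits=0.3,
)
```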
@@ -320,7 +321,7 @@ def evaluate(args, model, tokenizer, prefix=""):
if args.model_type in ["xlnet", "xlm"]:
# XLNet uses a more complex post-processing procedure
write_predictions_extended(
predictions = compute_predictions_log_probs(
examples,
features,
all_results,
@@ -337,7 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
args.verbose_logging,
)
else:
write_predictions(
predictions = compute_predictions_logits(
examples,
features,
all_results,
@@ -350,13 +351,11 @@ def evaluate(args, model, tokenizer, prefix=""):
args.verbose_logging,
args.version_2_with_negative,
args.null_score_diff_threshold,
tokenizer,
)
# Evaluate with the official SQuAD script
evaluate_options = EVAL_OPTS(
data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
)
results = evaluate_on_squad(evaluate_options)
# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)
return results
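Prediction writing and scoring now both live in the library, and `compute_predictions_logits` returns the predictions as well as writing the JSON files, so evaluation becomes one in-process call. A sketch of the flow, assuming `examples`, `features`, `all_results`, and `tokenizer` as built by this script (file names are placeholders):

```python
from transformers.data.metrics.squad_metrics import compute_predictions_logits, squad_evaluate

# Decode the collected SquadResults into answer strings; the return value
# maps each example's qas_id to its best predicted text.
predictions = compute_predictions_logits(
    examples,
    features,
    all_results,
    20,                   # n_best_size
    30,                   # max_answer_length
    True,                 # do_lower_case
    "predictions.json",   # output_prediction_file
    "nbest.json",         # output_nbest_file
    "null_odds.json",     # output_null_log_odds_file
    False,                # verbose_logging
    False,                # version_2_with_negative
    0.0,                  # null_score_diff_threshold
    tokenizer,
)

# Official SQuAD metrics without shelling out to the evaluation script.
results = squad_evaluate(examples, predictions)
print(results["exact"], results["f1"])
```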
@@ -368,59 +367,51 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
input_file = args.predict_file if evaluate else args.train_file
cached_features_file = os.path.join(
os.path.dirname(input_file),
"cached_{}_{}_{}".format(
"cached_distillation_{}_{}_{}".format(
"dev" if evaluate else "train",
list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length),
),
)
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
features_and_dataset = torch.load(cached_features_file)
try:
features, dataset, examples = (
features_and_dataset["features"],
features_and_dataset["dataset"],
features_and_dataset["examples"],
)
except KeyError:
raise DeprecationWarning(
"You seem to be loading features from an older version of this script; please delete the "
"file %s in order for it to be created again" % cached_features_file
)
else:
logger.info("Creating features from dataset file at %s", input_file)
examples = read_squad_examples(
input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
)
features = convert_examples_to_features(
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
if evaluate:
examples = processor.get_dev_examples(None, filename=args.predict_file)
else:
examples = processor.get_train_examples(None, filename=args.train_file)
features, dataset = squad_convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
return_dataset="pt",
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
if evaluate:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset(
all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
)
else:
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
dataset = TensorDataset(
all_input_ids,
all_input_mask,
all_segment_ids,
all_start_positions,
all_end_positions,
all_cls_index,
all_p_mask,
)
if output_examples:
return dataset, examples, features
return dataset
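The manual `read_squad_examples` → `convert_examples_to_features` → `TensorDataset` pipeline above is replaced by two library calls, which is also why the cache now stores the `{features, dataset, examples}` triple. A self-contained sketch of the new loading path (assuming a local `dev-v1.1.json` and that `bert-base-uncased` can be downloaded):

```python
from torch.utils.data import DataLoader, SequentialSampler
from transformers import AutoTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV1Processor

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# data_dir=None plus an explicit filename mirrors how this script points
# the processor at args.predict_file / args.train_file directly.
processor = SquadV1Processor()
examples = processor.get_dev_examples(None, filename="dev-v1.1.json")

# One call featurizes the examples and, with return_dataset="pt",
# also hands back a ready-made PyTorch TensorDataset.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
)

eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)
```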