Commit 164c794e authored by Lysandre, committed by Lysandre Debut

New SQuAD API for distillation script

parent 801f2ac8
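
For context, this commit replaces the script-local `utils_squad` / `utils_squad_evaluate` helpers with the SQuAD API that now ships inside `transformers`: the `SquadV1Processor`/`SquadV2Processor` readers, `squad_convert_examples_to_features`, the `SquadResult` container, and the `squad_metrics` prediction and evaluation functions. A minimal sketch of the pipeline the script migrates to is shown below; the checkpoint name, file name, and the 384/128/64 lengths are illustrative values, not part of the commit:

# Minimal sketch of the new SQuAD data pipeline (example values assumed).
from transformers import BertTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV1Processor

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV1Processor()
# data_dir is None because an explicit file is passed, as the diff below does.
examples = processor.get_dev_examples(None, filename="dev-v1.1.json")

# return_dataset="pt" also builds the TensorDataset that the old script
# assembled by hand, which is why the manual tensor plumbing disappears.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
)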
@@ -15,7 +15,6 @@
 # limitations under the License.
 """ This is the exact same script as `examples/run_squad.py` (as of 2019, October 4th) with an additional and optional step of distillation."""
-
 import argparse
 import glob
 import logging
@@ -26,7 +25,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -46,22 +45,14 @@ from transformers import (
     XLNetForQuestionAnswering,
     XLNetTokenizer,
     get_linear_schedule_with_warmup,
+    squad_convert_examples_to_features,
 )
-
-from ..utils_squad import (
-    RawResult,
-    RawResultExtended,
-    convert_examples_to_features,
-    read_squad_examples,
-    write_predictions,
-    write_predictions_extended,
-)
-
-# The follwing import is the official SQuAD evaluation script (2.0).
-# You can remove it from the dependencies if you are using this script outside of the library
-# We've added it here for automated tests (see examples/test_examples.py file)
-from ..utils_squad_evaluate import EVAL_OPTS
-from ..utils_squad_evaluate import main as evaluate_on_squad
+from transformers.data.metrics.squad_metrics import (
+    compute_predictions_log_probs,
+    compute_predictions_logits,
+    squad_evaluate,
+)
+from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor

 try:
@@ -69,7 +60,6 @@ try:
 except ImportError:
     from tensorboardX import SummaryWriter

-
 logger = logging.getLogger(__name__)

 ALL_MODELS = sum(
@@ -294,20 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""):
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
             unique_id = int(eval_feature.unique_id)
-            if args.model_type in ["xlnet", "xlm"]:
-                # XLNet uses a more complex post-processing procedure
-                result = RawResultExtended(
-                    unique_id=unique_id,
-                    start_top_log_probs=to_list(outputs[0][i]),
-                    start_top_index=to_list(outputs[1][i]),
-                    end_top_log_probs=to_list(outputs[2][i]),
-                    end_top_index=to_list(outputs[3][i]),
-                    cls_logits=to_list(outputs[4][i]),
+
+            output = [to_list(output[i]) for output in outputs]
+
+            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
+            # models only use two.
+            if len(output) >= 5:
+                start_logits = output[0]
+                start_top_index = output[1]
+                end_logits = output[2]
+                end_top_index = output[3]
+                cls_logits = output[4]
+
+                result = SquadResult(
+                    unique_id,
+                    start_logits,
+                    end_logits,
+                    start_top_index=start_top_index,
+                    end_top_index=end_top_index,
+                    cls_logits=cls_logits,
                 )
+
             else:
-                result = RawResult(
-                    unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
-                )
+                start_logits, end_logits = output
+                result = SquadResult(unique_id, start_logits, end_logits)
+
             all_results.append(result)

     # Compute predictions
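
A note on the hunk above: the new loop no longer branches on `args.model_type`; it keys on how many tensors the model returned, which keeps the collection logic model-agnostic. A self-contained sketch of the same pattern, assuming `outputs` is the tuple from the model forward pass for one batch and `i`/`unique_id` identify one example (`to_list` mirrors the helper the script defines, roughly `tensor.detach().cpu().tolist()`):

from transformers.data.processors.squad import SquadResult

def to_list(tensor):
    return tensor.detach().cpu().tolist()

# Per-example slices of every output tensor in the batch.
output = [to_list(o[i]) for o in outputs]

if len(output) >= 5:
    # XLNet/XLM-style heads: top log-probs and indices plus CLS logits.
    result = SquadResult(
        unique_id,
        output[0],  # start log-probs
        output[2],  # end log-probs
        start_top_index=output[1],
        end_top_index=output[3],
        cls_logits=output[4],
    )
else:
    # BERT-style heads: plain start/end logits.
    start_logits, end_logits = output
    result = SquadResult(unique_id, start_logits, end_logits)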
@@ -320,7 +321,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     if args.model_type in ["xlnet", "xlm"]:
         # XLNet uses a more complex post-processing procedure
-        write_predictions_extended(
+        predictions = compute_predictions_log_probs(
             examples,
             features,
             all_results,
@@ -337,7 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
             args.verbose_logging,
         )
     else:
-        write_predictions(
+        predictions = compute_predictions_logits(
             examples,
             features,
             all_results,
@@ -350,13 +351,11 @@ def evaluate(args, model, tokenizer, prefix=""):
             args.verbose_logging,
             args.version_2_with_negative,
             args.null_score_diff_threshold,
+            tokenizer,
         )

-    # Evaluate with the official SQuAD script
-    evaluate_options = EVAL_OPTS(
-        data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
-    )
-    results = evaluate_on_squad(evaluate_options)
+    # Compute the F1 and exact scores.
+    results = squad_evaluate(examples, predictions)

     return results
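
With predictions now produced in-process, evaluation also moves in-library: `squad_evaluate` replaces the call-out to the official SQuAD evaluation script (`EVAL_OPTS` and friends). A short usage sketch, assuming `examples` and `predictions` as built above:

from transformers.data.metrics.squad_metrics import squad_evaluate

# `predictions` maps each question id to its predicted answer text, as
# returned by compute_predictions_logits / compute_predictions_log_probs.
results = squad_evaluate(examples, predictions)
print("exact:", results["exact"], "f1:", results["f1"])  # percentages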
@@ -368,59 +367,51 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     input_file = args.predict_file if evaluate else args.train_file
     cached_features_file = os.path.join(
         os.path.dirname(input_file),
-        "cached_{}_{}_{}".format(
+        "cached_distillation_{}_{}_{}".format(
             "dev" if evaluate else "train",
             list(filter(None, args.model_name_or_path.split("/"))).pop(),
             str(args.max_seq_length),
         ),
     )
-    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
+    if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
-        features = torch.load(cached_features_file)
+        features_and_dataset = torch.load(cached_features_file)
+        try:
+            features, dataset, examples = (
+                features_and_dataset["features"],
+                features_and_dataset["dataset"],
+                features_and_dataset["examples"],
+            )
+        except KeyError:
+            raise DeprecationWarning(
+                "You seem to be loading features from an older version of this script please delete the "
+                "file %s in order for it to be created again" % cached_features_file
+            )
     else:
         logger.info("Creating features from dataset file at %s", input_file)
-        examples = read_squad_examples(
-            input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
-        )
-        features = convert_examples_to_features(
+        processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
+        if evaluate:
+            examples = processor.get_dev_examples(None, filename=args.predict_file)
+        else:
+            examples = processor.get_train_examples(None, filename=args.train_file)
+
+        features, dataset = squad_convert_examples_to_features(
             examples=examples,
             tokenizer=tokenizer,
             max_seq_length=args.max_seq_length,
             doc_stride=args.doc_stride,
             max_query_length=args.max_query_length,
             is_training=not evaluate,
+            return_dataset="pt",
         )
+
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
-            torch.save(features, cached_features_file)
+            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

     if args.local_rank == 0 and not evaluate:
         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

-    # Convert to Tensors and build dataset
-    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
-    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
-    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
-    if evaluate:
-        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-        dataset = TensorDataset(
-            all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
-        )
-    else:
-        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
-        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-        dataset = TensorDataset(
-            all_input_ids,
-            all_input_mask,
-            all_segment_ids,
-            all_start_positions,
-            all_end_positions,
-            all_cls_index,
-            all_p_mask,
-        )
-
     if output_examples:
         return dataset, examples, features
     return dataset
...
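
Two details of the new caching scheme are worth noting. The cache file now carries the "cached_distillation_" prefix so this script never picks up an incompatible cache written by the plain run_squad.py, and it stores a dict of features, dataset, and examples rather than a bare feature list. A sketch of the round-trip, with a hypothetical file name:

import torch

CACHE = "cached_distillation_dev_bert-base-uncased_384"  # hypothetical name

# Save all three objects together, as the new script does.
torch.save({"features": features, "dataset": dataset, "examples": examples}, CACHE)

# Load them back; an old-format cache trips the guard below.
features_and_dataset = torch.load(CACHE)
try:
    features = features_and_dataset["features"]
    dataset = features_and_dataset["dataset"]
    examples = features_and_dataset["examples"]
except KeyError:
    # A cache written by the old script is not a dict keyed this way;
    # the script asks for it to be deleted so it can be rebuilt.
    raise DeprecationWarning("Old-format cache; delete the file and rerun.")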