"src/sdk/pycli/vscode:/vscode.git/clone" did not exist on "cd3a912a07c2187fcd9e81205da35b97746b76f3"
Unverified Commit 49296533 authored by Zhangyx, committed by GitHub

Adds a predict stage for GLUE tasks and generates result files that can be submitted to gluebenchmark.com (#4463)

* Adds a predict stage for GLUE tasks and generates result files that can be submitted to the gluebenchmark.com website.

* Use the Split enum and always output the label name.

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent 271bedb4
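In short, the new stage builds a test-split dataset, runs Trainer.predict, and writes one TSV per task in the index/prediction format that gluebenchmark.com expects. Below is a minimal sketch of that flow, condensed from the diff that follows; the MRPC task, checkpoint name, and paths are illustrative assumptions, not part of this commit.

import numpy as np
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    GlueDataTrainingArguments,
    Trainer,
    TrainingArguments,
)

# Illustrative setup; any GLUE task and fine-tuned checkpoint would do.
data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# mode="test" is the new entry point added by this commit.
test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test")
trainer = Trainer(model=model, args=TrainingArguments(output_dir="./out"))

predictions = trainer.predict(test_dataset=test_dataset).predictions
predictions = np.argmax(predictions, axis=1)  # MRPC is a classification task

# gluebenchmark.com expects an index column plus the predicted label name.
with open("test_results_mrpc.txt", "w") as writer:
    writer.write("index\tprediction\n")
    for index, item in enumerate(predictions):
        writer.write("%d\t%s\n" % (index, test_dataset.get_labels()[item]))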
@@ -419,7 +419,7 @@ def main():
logger.info("Training/evaluation parameters %s", args)
# Prepare dataset for the GLUE task
- eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
+ eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
if args.data_subset > 0:
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
@@ -135,7 +135,8 @@ def main():
# Get datasets
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
- eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
+ eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
+ test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None
def compute_metrics(p: EvalPrediction) -> Dict:
if output_mode == "classification":
@@ -165,7 +166,7 @@ def main():
tokenizer.save_pretrained(training_args.output_dir)
# Evaluation
- results = {}
+ eval_results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
@@ -173,10 +174,10 @@ def main():
eval_datasets = [eval_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
- eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
+ eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))
for eval_dataset in eval_datasets:
- result = trainer.evaluate(eval_dataset=eval_dataset)
+ eval_result = trainer.evaluate(eval_dataset=eval_dataset)
output_eval_file = os.path.join(
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
@@ -184,13 +185,38 @@ def main():
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
- for key, value in result.items():
+ for key, value in eval_result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
- results.update(result)
+ eval_results.update(eval_result)
- return results
+ if training_args.do_predict:
+     logging.info("*** Test ***")
+     test_datasets = [test_dataset]
+     if data_args.task_name == "mnli":
+         mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
+         test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
+     for test_dataset in test_datasets:
+         predictions = trainer.predict(test_dataset=test_dataset).predictions
+         if output_mode == "classification":
+             predictions = np.argmax(predictions, axis=1)
+         output_test_file = os.path.join(
+             training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
+         )
+         if trainer.is_world_master():
+             with open(output_test_file, "w") as writer:
+                 logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
+                 writer.write("index\tprediction\n")
+                 for index, item in enumerate(predictions):
+                     if output_mode == "regression":
+                         writer.write("%d\t%3.3f\n" % (index, item))
+                     else:
+                         item = test_dataset.get_labels()[item]
+                         writer.write("%d\t%s\n" % (index, item))
+ return eval_results
def _mp_fn(index):
@@ -2,7 +2,8 @@ import logging
import os
import time
from dataclasses import dataclass, field
- from typing import List, Optional
+ from enum import Enum
+ from typing import List, Optional, Union
import torch
from filelock import FileLock
@@ -47,6 +48,12 @@ class GlueDataTrainingArguments:
self.task_name = self.task_name.lower()
+ class Split(Enum):
+     train = "train"
+     dev = "dev"
+     test = "test"
class GlueDataset(Dataset):
"""
This will be superseded by a framework-agnostic approach
@@ -62,16 +69,21 @@ class GlueDataset(Dataset):
args: GlueDataTrainingArguments,
tokenizer: PreTrainedTokenizer,
limit_length: Optional[int] = None,
- evaluate=False,
+ mode: Union[str, Split] = Split.train,
):
self.args = args
- processor = glue_processors[args.task_name]()
+ self.processor = glue_processors[args.task_name]()
self.output_mode = glue_output_modes[args.task_name]
+ if isinstance(mode, str):
+     try:
+         mode = Split[mode]
+     except KeyError:
+         raise KeyError("mode is not a valid split name")
# Load data features from cache or dataset file
cached_features_file = os.path.join(
args.data_dir,
"cached_{}_{}_{}_{}".format(
- "dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
+ mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
),
)
@@ -88,7 +100,7 @@ class GlueDataset(Dataset):
)
else:
logger.info(f"Creating features from dataset file at {args.data_dir}")
- label_list = processor.get_labels()
+ label_list = self.processor.get_labels()
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
RobertaTokenizer,
RobertaTokenizerFast,
@@ -96,11 +108,12 @@
):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
- examples = (
-     processor.get_dev_examples(args.data_dir)
-     if evaluate
-     else processor.get_train_examples(args.data_dir)
- )
+ if mode == Split.dev:
+     examples = self.processor.get_dev_examples(args.data_dir)
+ elif mode == Split.test:
+     examples = self.processor.get_test_examples(args.data_dir)
+ else:
+     examples = self.processor.get_train_examples(args.data_dir)
if limit_length is not None:
examples = examples[:limit_length]
self.features = glue_convert_examples_to_features(
@@ -122,3 +135,6 @@ class GlueDataset(Dataset):
def __getitem__(self, i) -> InputFeatures:
return self.features[i]
+ def get_labels(self):
+     return self.processor.get_labels()
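The mode argument accepts either a Split member or its string name; anything else fails fast with a KeyError. A quick sketch of that coercion behaviour follows; the import path mirrors where Split is defined in this diff, so treat it as an assumption about module layout.

from transformers.data.datasets.glue import Split

assert Split["dev"] is Split.dev    # string name -> enum member, as in __init__
assert Split.dev.value == "dev"     # .value feeds the cached_features_file name

try:
    Split["validation"]             # not one of train/dev/test
except KeyError:
    print("rejected, just like GlueDataset(mode='validation') would be")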
@@ -126,7 +126,9 @@ def _glue_convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
- def label_from_example(example: InputExample) -> Union[int, float]:
+ def label_from_example(example: InputExample) -> Union[int, float, None]:
+     if example.label is None:
+         return None
if output_mode == "classification":
return label_map[example.label]
elif output_mode == "regression":
@@ -180,12 +182,16 @@ class MrpcProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
- """Creates examples for the training and dev sets."""
+ """Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
@@ -193,7 +199,7 @@ class MrpcProcessor(DataProcessor):
guid = "%s-%s" % (set_type, i)
text_a = line[3]
text_b = line[4]
- label = line[0]
+ label = None if set_type == "test" else line[0]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
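With this change, test-split examples carry label=None, which the updated label_from_example passes through instead of failing on a missing gold column. A small hedged check of that behaviour; the data path is an assumption.

from transformers.data.processors.glue import MrpcProcessor

processor = MrpcProcessor()
# Dev examples keep their gold labels; test examples now come back unlabeled.
dev_examples = processor.get_dev_examples("./glue_data/MRPC")
test_examples = processor.get_test_examples("./glue_data/MRPC")
print(dev_examples[0].label)   # e.g. "0" or "1"
print(test_examples[0].label)  # None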
@@ -218,12 +224,16 @@ class MnliProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
def _create_examples(self, lines, set_type):
- """Creates examples for the training and dev sets."""
+ """Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
@@ -231,7 +241,7 @@ class MnliProcessor(DataProcessor):
guid = "%s-%s" % (set_type, line[0])
text_a = line[8]
text_b = line[9]
- label = line[-1]
+ label = None if set_type.startswith("test") else line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -241,7 +251,11 @@ class MnliMismatchedProcessor(MnliProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
- return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
class ColaProcessor(DataProcessor):
@@ -264,17 +278,25 @@ class ColaProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
test_mode = set_type == "test"
if test_mode:
lines = lines[1:]
text_index = 1 if test_mode else 3
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
- text_a = line[3]
- label = line[1]
+ text_a = line[text_index]
+ label = None if test_mode else line[1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
@@ -299,19 +321,23 @@ class Sst2Processor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
- label = line[1]
+ label = None if set_type == "test" else line[1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
@@ -336,12 +362,16 @@ class StsbProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return [None]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
@@ -349,7 +379,7 @@ class StsbProcessor(DataProcessor):
guid = "%s-%s" % (set_type, line[0])
text_a = line[7]
text_b = line[8]
- label = line[-1]
+ label = None if set_type == "test" else line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -374,21 +404,28 @@ class QqpProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
test_mode = set_type == "test"
q1_index = 1 if test_mode else 3
q2_index = 2 if test_mode else 4
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
try:
- text_a = line[3]
- text_b = line[4]
- label = line[5]
+ text_a = line[q1_index]
+ text_b = line[q2_index]
+ label = None if test_mode else line[5]
except IndexError:
continue
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
@@ -413,14 +450,18 @@ class QnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
- return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
@@ -428,7 +469,7 @@ class QnliProcessor(DataProcessor):
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
- label = line[-1]
+ label = None if set_type == "test" else line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -453,12 +494,16 @@ class RteProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
@@ -466,7 +511,7 @@ class RteProcessor(DataProcessor):
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
- label = line[-1]
+ label = None if set_type == "test" else line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -491,12 +536,16 @@ class WnliProcessor(DataProcessor):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+ def get_test_examples(self, data_dir):
+     """See base class."""
+     return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
@@ -504,7 +553,7 @@ class WnliProcessor(DataProcessor):
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
- label = line[-1]
+ label = None if set_type == "test" else line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -98,6 +98,10 @@ class DataProcessor:
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
+ def get_test_examples(self, data_dir):
+     """Gets a collection of `InputExample`s for the test set."""
+     raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
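Any custom processor now needs this third getter to work with mode="test". Below is a hypothetical subclass sketching the full contract; MyBinaryTaskProcessor and its TSV layout are inventions for illustration, not part of the library.

import os

from transformers.data.processors.utils import DataProcessor, InputExample

class MyBinaryTaskProcessor(DataProcessor):
    """Hypothetical processor for a two-label TSV task: sentence<TAB>label."""

    def get_train_examples(self, data_dir):
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            if i == 0:  # skip the header row
                continue
            # Test files ship without gold labels, mirroring the GLUE processors above.
            label = None if set_type == "test" else line[1]
            examples.append(InputExample(guid="%s-%s" % (set_type, i), text_a=line[0], label=label))
        return examples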
@@ -30,7 +30,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
data_args = GlueDataTrainingArguments(
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
)
- dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
+ dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator()
batch = data_collator.collate_batch(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.long)
@@ -41,7 +41,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
data_args = GlueDataTrainingArguments(
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
)
- dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
+ dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator()
batch = data_collator.collate_batch(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.float)
@@ -93,7 +93,7 @@ class TrainerIntegrationTest(unittest.TestCase):
data_args = GlueDataTrainingArguments(
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
)
- eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
+ eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)