Unverified commit 49296533, authored by Zhangyx, committed by GitHub

Adds a predict stage for GLUE tasks and generates result files that can be submitted to gluebenchmark.com (#4463)

* Adds a predict stage for GLUE tasks, generating result files that can be submitted to the gluebenchmark.com website.
* Use Split enum + always output the label name

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent 271bedb4
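For orientation before the diff: a minimal sketch of how the reworked GlueDataset API introduced by this commit is meant to be used. This snippet is not part of the commit; the checkpoint name and data directory are placeholder values, and only the `mode` argument (which replaces the old `evaluate=True` flag) is taken from the changes below.

    # Sketch only: "bert-base-cased" and "./glue_data/MRPC" are placeholders, not part of the diff.
    from transformers import AutoTokenizer, GlueDataset, GlueDataTrainingArguments

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")

    # The boolean `evaluate=True` flag is replaced by a string (or Split enum) mode.
    train_dataset = GlueDataset(data_args, tokenizer=tokenizer)              # defaults to Split.train
    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")   # previously evaluate=True
    test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test")  # new: unlabeled test split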
@@ -419,7 +419,7 @@ def main():
     logger.info("Training/evaluation parameters %s", args)
     # Prepare dataset for the GLUE task
-    eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
+    eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
     if args.data_subset > 0:
         eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
...
@@ -135,7 +135,8 @@ def main():
     # Get datasets
     train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
-    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
+    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
+    test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None

     def compute_metrics(p: EvalPrediction) -> Dict:
         if output_mode == "classification":
@@ -165,7 +166,7 @@ def main():
         tokenizer.save_pretrained(training_args.output_dir)

     # Evaluation
-    results = {}
+    eval_results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
@@ -173,10 +174,10 @@ def main():
         eval_datasets = [eval_dataset]
         if data_args.task_name == "mnli":
             mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
-            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
+            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))

         for eval_dataset in eval_datasets:
-            result = trainer.evaluate(eval_dataset=eval_dataset)
+            eval_result = trainer.evaluate(eval_dataset=eval_dataset)

             output_eval_file = os.path.join(
                 training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
@@ -184,13 +185,38 @@ def main():
             if trainer.is_world_master():
                 with open(output_eval_file, "w") as writer:
                     logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
-                    for key, value in result.items():
+                    for key, value in eval_result.items():
                         logger.info("  %s = %s", key, value)
                         writer.write("%s = %s\n" % (key, value))

-            results.update(result)
+            eval_results.update(eval_result)

-    return results
+    if training_args.do_predict:
+        logging.info("*** Test ***")
+        test_datasets = [test_dataset]
+        if data_args.task_name == "mnli":
+            mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
+            test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
+
+        for test_dataset in test_datasets:
+            predictions = trainer.predict(test_dataset=test_dataset).predictions
+            if output_mode == "classification":
+                predictions = np.argmax(predictions, axis=1)
+
+            output_test_file = os.path.join(
+                training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
+            )
+            if trainer.is_world_master():
+                with open(output_test_file, "w") as writer:
+                    logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
+                    writer.write("index\tprediction\n")
+                    for index, item in enumerate(predictions):
+                        if output_mode == "regression":
+                            writer.write("%d\t%3.3f\n" % (index, item))
+                        else:
+                            item = test_dataset.get_labels()[item]
+                            writer.write("%d\t%s\n" % (index, item))
+
+    return eval_results

 def _mp_fn(index):
...
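As a purely illustrative aside (the prediction rows below are hypothetical, not produced by this commit): for a classification task such as MNLI, the predict stage above writes a tab-separated file named test_results_mnli.txt, with label names taken from test_dataset.get_labels(), roughly of the form

    index    prediction
    0        entailment
    1        neutral
    2        contradiction

while regression tasks such as STS-B get the raw score formatted with %3.3f instead of a label name.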
@@ -2,7 +2,8 @@ import logging
 import os
 import time
 from dataclasses import dataclass, field
-from typing import List, Optional
+from enum import Enum
+from typing import List, Optional, Union

 import torch
 from filelock import FileLock
@@ -47,6 +48,12 @@ class GlueDataTrainingArguments:
         self.task_name = self.task_name.lower()


+class Split(Enum):
+    train = "train"
+    dev = "dev"
+    test = "test"
+
+
 class GlueDataset(Dataset):
     """
     This will be superseded by a framework-agnostic approach
@@ -62,16 +69,21 @@ class GlueDataset(Dataset):
         args: GlueDataTrainingArguments,
         tokenizer: PreTrainedTokenizer,
         limit_length: Optional[int] = None,
-        evaluate=False,
+        mode: Union[str, Split] = Split.train,
     ):
         self.args = args
-        processor = glue_processors[args.task_name]()
+        self.processor = glue_processors[args.task_name]()
         self.output_mode = glue_output_modes[args.task_name]
+        if isinstance(mode, str):
+            try:
+                mode = Split[mode]
+            except KeyError:
+                raise KeyError("mode is not a valid split name")
         # Load data features from cache or dataset file
         cached_features_file = os.path.join(
             args.data_dir,
             "cached_{}_{}_{}_{}".format(
-                "dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
+                mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
             ),
         )
@@ -88,7 +100,7 @@ class GlueDataset(Dataset):
                 )
             else:
                 logger.info(f"Creating features from dataset file at {args.data_dir}")
-                label_list = processor.get_labels()
+                label_list = self.processor.get_labels()
                 if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
                     RobertaTokenizer,
                     RobertaTokenizerFast,
@@ -96,11 +108,12 @@ class GlueDataset(Dataset):
                 ):
                     # HACK(label indices are swapped in RoBERTa pretrained model)
                     label_list[1], label_list[2] = label_list[2], label_list[1]
-                examples = (
-                    processor.get_dev_examples(args.data_dir)
-                    if evaluate
-                    else processor.get_train_examples(args.data_dir)
-                )
+                if mode == Split.dev:
+                    examples = self.processor.get_dev_examples(args.data_dir)
+                elif mode == Split.test:
+                    examples = self.processor.get_test_examples(args.data_dir)
+                else:
+                    examples = self.processor.get_train_examples(args.data_dir)
                 if limit_length is not None:
                     examples = examples[:limit_length]
                 self.features = glue_convert_examples_to_features(
@@ -122,3 +135,6 @@ class GlueDataset(Dataset):
     def __getitem__(self, i) -> InputFeatures:
         return self.features[i]
+
+    def get_labels(self):
+        return self.processor.get_labels()
@@ -126,7 +126,9 @@ def _glue_convert_examples_to_features(
     label_map = {label: i for i, label in enumerate(label_list)}

-    def label_from_example(example: InputExample) -> Union[int, float]:
+    def label_from_example(example: InputExample) -> Union[int, float, None]:
+        if example.label is None:
+            return None
         if output_mode == "classification":
             return label_map[example.label]
         elif output_mode == "regression":
@@ -180,12 +182,16 @@ class MrpcProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["0", "1"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
@@ -193,7 +199,7 @@ class MrpcProcessor(DataProcessor):
             guid = "%s-%s" % (set_type, i)
             text_a = line[3]
             text_b = line[4]
-            label = line[0]
+            label = None if set_type == "test" else line[0]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
@@ -218,12 +224,16 @@ class MnliProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
+
     def get_labels(self):
         """See base class."""
         return ["contradiction", "entailment", "neutral"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
@@ -231,7 +241,7 @@ class MnliProcessor(DataProcessor):
             guid = "%s-%s" % (set_type, line[0])
             text_a = line[8]
             text_b = line[9]
-            label = line[-1]
+            label = None if set_type.startswith("test") else line[-1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
@@ -241,7 +251,11 @@ class MnliMismatchedProcessor(MnliProcessor):
     def get_dev_examples(self, data_dir):
         """See base class."""
-        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")


 class ColaProcessor(DataProcessor):
@@ -264,17 +278,25 @@ class ColaProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["0", "1"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
+        test_mode = set_type == "test"
+        if test_mode:
+            lines = lines[1:]
+        text_index = 1 if test_mode else 3
         examples = []
         for (i, line) in enumerate(lines):
             guid = "%s-%s" % (set_type, i)
-            text_a = line[3]
-            label = line[1]
+            text_a = line[text_index]
+            label = None if test_mode else line[1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
         return examples
@@ -299,19 +321,23 @@ class Sst2Processor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["0", "1"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
             guid = "%s-%s" % (set_type, i)
             text_a = line[0]
-            label = line[1]
+            label = None if set_type == "test" else line[1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
         return examples
@@ -336,12 +362,16 @@ class StsbProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return [None]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
@@ -349,7 +379,7 @@ class StsbProcessor(DataProcessor):
             guid = "%s-%s" % (set_type, line[0])
             text_a = line[7]
             text_b = line[8]
-            label = line[-1]
+            label = None if set_type == "test" else line[-1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
@@ -374,21 +404,28 @@ class QqpProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["0", "1"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
+        test_mode = set_type == "test"
+        q1_index = 1 if test_mode else 3
+        q2_index = 2 if test_mode else 4
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
                 continue
             guid = "%s-%s" % (set_type, line[0])
             try:
-                text_a = line[3]
-                text_b = line[4]
-                label = line[5]
+                text_a = line[q1_index]
+                text_b = line[q2_index]
+                label = None if test_mode else line[5]
             except IndexError:
                 continue
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
@@ -413,14 +450,18 @@ class QnliProcessor(DataProcessor):
     def get_dev_examples(self, data_dir):
         """See base class."""
-        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["entailment", "not_entailment"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
@@ -428,7 +469,7 @@ class QnliProcessor(DataProcessor):
             guid = "%s-%s" % (set_type, line[0])
             text_a = line[1]
             text_b = line[2]
-            label = line[-1]
+            label = None if set_type == "test" else line[-1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
@@ -453,12 +494,16 @@ class RteProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["entailment", "not_entailment"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
@@ -466,7 +511,7 @@ class RteProcessor(DataProcessor):
             guid = "%s-%s" % (set_type, line[0])
             text_a = line[1]
             text_b = line[2]
-            label = line[-1]
+            label = None if set_type == "test" else line[-1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
@@ -491,12 +536,16 @@ class WnliProcessor(DataProcessor):
         """See base class."""
         return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
     def get_labels(self):
         """See base class."""
         return ["0", "1"]

     def _create_examples(self, lines, set_type):
-        """Creates examples for the training and dev sets."""
+        """Creates examples for the training, dev and test sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if i == 0:
@@ -504,7 +553,7 @@ class WnliProcessor(DataProcessor):
             guid = "%s-%s" % (set_type, line[0])
             text_a = line[1]
             text_b = line[2]
-            label = line[-1]
+            label = None if set_type == "test" else line[-1]
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
...
@@ -98,6 +98,10 @@ class DataProcessor:
         """Gets a collection of `InputExample`s for the dev set."""
         raise NotImplementedError()

+    def get_test_examples(self, data_dir):
+        """Gets a collection of `InputExample`s for the test set."""
+        raise NotImplementedError()
+
     def get_labels(self):
         """Gets the list of labels for this data set."""
         raise NotImplementedError()
...
@@ -30,7 +30,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         data_args = GlueDataTrainingArguments(
             task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
         )
-        dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
+        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
         data_collator = DefaultDataCollator()
         batch = data_collator.collate_batch(dataset.features)
         self.assertEqual(batch["labels"].dtype, torch.long)
@@ -41,7 +41,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         data_args = GlueDataTrainingArguments(
             task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
         )
-        dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
+        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
         data_collator = DefaultDataCollator()
         batch = data_collator.collate_batch(dataset.features)
         self.assertEqual(batch["labels"].dtype, torch.float)
@@ -93,7 +93,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         data_args = GlueDataTrainingArguments(
             task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
         )
-        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
+        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
         training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
         trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
...