Commit 289cf4d2 authored by VictorSanh's avatar VictorSanh Committed by Lysandre Debut
Browse files

change default for XNLI: dev --> test

parent cb7b77a8
...@@ -270,7 +270,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -270,7 +270,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
output_mode = output_modes[task] output_mode = output_modes[task]
# Load data features from cache or dataset file # Load data features from cache or dataset file
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format( cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format(
'dev' if evaluate else 'train', 'test' if evaluate else 'train',
list(filter(None, args.model_name_or_path.split('/'))).pop(), list(filter(None, args.model_name_or_path.split('/'))).pop(),
str(args.max_seq_length), str(args.max_seq_length),
str(task), str(task),
...@@ -281,7 +281,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -281,7 +281,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
else: else:
logger.info("Creating features from dataset file at %s", args.data_dir) logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels() label_list = processor.get_labels()
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) examples = processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, features = convert_examples_to_features(examples,
tokenizer, tokenizer,
label_list=label_list, label_list=label_list,
...@@ -341,7 +341,7 @@ def main(): ...@@ -341,7 +341,7 @@ def main():
parser.add_argument("--do_train", action='store_true', parser.add_argument("--do_train", action='store_true',
help="Whether to run training.") help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.") help="Whether to run eval on the test set.")
parser.add_argument("--evaluate_during_training", action='store_true', parser.add_argument("--evaluate_during_training", action='store_true',
help="Rul evaluation during training at each logging step.") help="Rul evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true', parser.add_argument("--do_lower_case", action='store_true',
......
...@@ -50,9 +50,9 @@ class XnliProcessor(DataProcessor): ...@@ -50,9 +50,9 @@ class XnliProcessor(DataProcessor):
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
def get_dev_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.dev.tsv")) lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: if i == 0:
...@@ -60,7 +60,7 @@ class XnliProcessor(DataProcessor): ...@@ -60,7 +60,7 @@ class XnliProcessor(DataProcessor):
language = line[0] language = line[0]
if language != self.language: if language != self.language:
continue continue
guid = "%s-%s" % ('dev', i) guid = "%s-%s" % ('test', i)
text_a = line[6] text_a = line[6]
text_b = line[7] text_b = line[7]
label = line[1] label = line[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment