Commit 80a21374 authored by stephenwu

added RTE preprocessor

parent b46db8ad
@@ -120,6 +120,15 @@ class DataProcessor(object):
         lines.append(line)
       return lines
 
+  @classmethod
+  def _read_jsonl(cls, input_path):
+    """Reads a json line file."""
+    with tf.io.gfile.GFile(input_path, "r") as f:
+      lines = []
+      for json_str in f:
+        lines.append(json.loads(json_str))
+    return lines
+
 
 class AxProcessor(DataProcessor):
   """Processor for the AX dataset (GLUE diagnostics dataset)."""
@@ -1277,7 +1286,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
   return feature
 
 
 class AXgProcessor(DataProcessor):
-  """Processor for the AX dataset (GLUE diagnostics dataset)."""
+  """Processor for the AXg dataset (SuperGLUE diagnostics dataset)."""
 
   def get_test_examples(self, data_dir):
     """See base class."""
@@ -1298,20 +1307,56 @@ class AXgProcessor(DataProcessor):
     examples = []
     for line in lines:
       guid = "%s-%s" % (set_type, self.process_text_fn(str(line['idx'])))
-      text_a = self.process_text_fn(line["hypothesis"])
-      text_b = self.process_text_fn(line["premise"])
+      text_a = self.process_text_fn(line["premise"])
+      text_b = self.process_text_fn(line["hypothesis"])
       label = self.process_text_fn(line["label"])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
-  def _read_jsonl(self, input_path):
-    with tf.io.gfile.GFile(input_path, "r") as f:
-      lines = []
-      for json_str in f:
-        lines.append(json.loads(json_str))
-    return lines
+
+class RTESuperGLUEProcessor(DataProcessor):
+  """Processor for the RTE dataset (SuperGLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_jsonl(os.path.join(data_dir, "val.jsonl")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    # All datasets are converted to 2-class split, where for 3-class datasets we
+    # collapse neutral and contradiction into not_entailment.
+    return ["entailment", "not_entailment"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "RTESuperGLUE"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "%s-%s" % (set_type, i)
+      text_a = self.process_text_fn(line['premise'])
+      text_b = self.process_text_fn(line['hypothesis'])
+      if set_type == "test":
+        label = "entailment"
+      else:
+        label = self.process_text_fn(line['label'])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
 
 
 def file_based_convert_examples_to_features(examples,
                                             label_list,
......
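A note on the labeling convention above: SuperGLUE withholds test-set labels, so RTESuperGLUEProcessor emits the placeholder "entailment" for test examples to keep the downstream feature conversion uniform. A standalone sketch of that branch, using an invented record whose field names come from the diff:

# Invented val.jsonl-style record; only the field names are taken from the diff.
line = {"premise": "It is raining.", "hypothesis": "The ground is wet.",
        "label": "not_entailment", "idx": 7}

for set_type in ("dev", "test"):
  guid = "%s-%s" % (set_type, 0)
  # Test labels are hidden in SuperGLUE, so a placeholder label is used there.
  label = "entailment" if set_type == "test" else line["label"]
  print(guid, line["premise"], line["hypothesis"], label)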
@@ -49,7 +49,8 @@ flags.DEFINE_string(
 flags.DEFINE_enum(
     "classification_task_name", "MNLI", [
         "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
-        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X", "AX-g"
+        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X", "AX-g",
+        "RTE-SuperGLUE"
     ], "The name of the task to train BERT classifier. The "
     "difference between XTREME-XNLI and XNLI is: 1. the format "
     "of input tsv files; 2. the dev set for XTREME is english "
@@ -240,7 +241,9 @@ def generate_classifier_dataset():
               translated_data_dir=FLAGS.translated_input_data_dir,
               only_use_en_dev=FLAGS.only_use_en_dev),
       "ax-g":
-          classifier_data_lib.AXgProcessor
+          classifier_data_lib.AXgProcessor,
+      "rte-superglue":
+          classifier_data_lib.RTESuperGLUEProcessor
   }
   task_name = FLAGS.classification_task_name.lower()
......
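For context, generate_classifier_dataset() resolves the flag by lowercasing it and indexing the processors dict, which is why the enum value "RTE-SuperGLUE" pairs with the "rte-superglue" key. A minimal sketch of that lookup follows; the import path is an assumption, since the diff does not show the module location.

# Hypothetical module path; adjust to wherever classifier_data_lib lives.
from official.nlp.data import classifier_data_lib

processors = {
    "ax-g": classifier_data_lib.AXgProcessor,
    "rte-superglue": classifier_data_lib.RTESuperGLUEProcessor,
}

# Mirrors task_name = FLAGS.classification_task_name.lower() in the script.
task_name = "RTE-SuperGLUE".lower()
processor_class = processors[task_name]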