"...test_cli/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "9f352df0eb66a2c55b9fc40b02173a471c3e5ee4"
Commit 2e398eca authored by A. Unique TensorFlower
Browse files

Merge pull request #9722 from supersteph:RTESuperGLUE

PiperOrigin-RevId: 357217962
parents 0326425d 440e0eec
...@@ -120,6 +120,15 @@ class DataProcessor(object): ...@@ -120,6 +120,15 @@ class DataProcessor(object):
lines.append(line) lines.append(line)
return lines return lines
@classmethod
def _read_jsonl(cls, input_file):
  """Reads a JSON Lines file into a list of decoded objects.

  Args:
    input_file: Path passed to `tf.io.gfile.GFile` and opened in text mode.

  Returns:
    A list holding one `json.loads` result per non-blank line, in file order.

  Raises:
    json.JSONDecodeError: If a non-blank line is not valid JSON.
  """
  with tf.io.gfile.GFile(input_file, "r") as f:
    # Whitespace-only lines (for example a stray trailing newline) are
    # skipped: handing them to json.loads would raise JSONDecodeError and
    # abort the whole load.
    return [json.loads(json_str) for json_str in f if json_str.strip()]
class AxProcessor(DataProcessor): class AxProcessor(DataProcessor):
"""Processor for the AX dataset (GLUE diagnostics dataset).""" """Processor for the AX dataset (GLUE diagnostics dataset)."""
...@@ -1278,7 +1287,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -1278,7 +1287,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
class AXgProcessor(DataProcessor): class AXgProcessor(DataProcessor):
"""Processor for the AXg dataset (GLUE diagnostics dataset).""" """Processor for the AXg dataset (SuperGLUE diagnostics dataset)."""
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -1299,19 +1308,57 @@ class AXgProcessor(DataProcessor): ...@@ -1299,19 +1308,57 @@ class AXgProcessor(DataProcessor):
examples = [] examples = []
for line in lines: for line in lines:
guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"]))) guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
text_a = self.process_text_fn(line["hypothesis"]) text_a = self.process_text_fn(line["premise"])
text_b = self.process_text_fn(line["premise"]) text_b = self.process_text_fn(line["hypothesis"])
label = self.process_text_fn(line["label"]) label = self.process_text_fn(line["label"])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
def _read_jsonl(self, input_path):
with tf.io.gfile.GFile(input_path, "r") as f: class SuperGLUERTEProcessor(DataProcessor):
lines = [] """Processor for the RTE dataset (SuperGLUE version)."""
for json_str in f:
lines.append(json.loads(json_str)) def get_train_examples(self, data_dir):
return lines """See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")
def get_dev_examples(self, data_dir):
  """Builds the "dev" examples from `val.jsonl` located in `data_dir`."""
  dev_path = os.path.join(data_dir, "val.jsonl")
  records = self._read_jsonl(dev_path)
  return self._create_examples(records, "dev")
def get_test_examples(self, data_dir):
  """Builds the "test" examples from `test.jsonl` located in `data_dir`."""
  test_path = os.path.join(data_dir, "test.jsonl")
  records = self._read_jsonl(test_path)
  return self._create_examples(records, "test")
def get_labels(self):
  """Returns the two class labels, in order: entailment, not_entailment."""
  # Every dataset is presented as a 2-class problem; for 3-class datasets,
  # neutral and contradiction are collapsed into not_entailment.
  label_names = ["entailment", "not_entailment"]
  return label_names
@staticmethod
def get_processor_name():
  """Returns the string identifier of this processor."""
  processor_name = "RTESuperGLUE"
  return processor_name
def _create_examples(self, lines, set_type):
  """Converts decoded records into a list of `InputExample` objects.

  Args:
    lines: Iterable of mappings read with the keys "premise" and
      "hypothesis"; "label" is also read unless `set_type` is "test".
    set_type: Split name used as the guid prefix. The value "test" switches
      on the placeholder label.

  Returns:
    A list of `InputExample`, one per record, with guid
    "<set_type>-<position>", premise as `text_a` and hypothesis as `text_b`.
  """
  use_placeholder_label = set_type == "test"
  examples = []
  for position, record in enumerate(lines):
    premise = self.process_text_fn(record["premise"])
    hypothesis = self.process_text_fn(record["hypothesis"])
    # For the test split the record's "label" key is never read; a fixed
    # "entailment" label is used instead (presumably because test files ship
    # without gold labels -- confirm against the data).
    if use_placeholder_label:
      label = "entailment"
    else:
      label = self.process_text_fn(record["label"])
    example = InputExample(
        guid="%s-%s" % (set_type, position),
        text_a=premise,
        text_b=hypothesis,
        label=label)
    examples.append(example)
  return examples
def file_based_convert_examples_to_features(examples, def file_based_convert_examples_to_features(examples,
......
...@@ -49,7 +49,8 @@ flags.DEFINE_string( ...@@ -49,7 +49,8 @@ flags.DEFINE_string(
flags.DEFINE_enum( flags.DEFINE_enum(
"classification_task_name", "MNLI", [ "classification_task_name", "MNLI", [
"AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X", "AX-g" "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
"AX-g", "SUPERGLUE-RTE"
], "The name of the task to train BERT classifier. The " ], "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format " "difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english " "of input tsv files; 2. the dev set for XTREME is english "
...@@ -240,7 +241,9 @@ def generate_classifier_dataset(): ...@@ -240,7 +241,9 @@ def generate_classifier_dataset():
translated_data_dir=FLAGS.translated_input_data_dir, translated_data_dir=FLAGS.translated_input_data_dir,
only_use_en_dev=FLAGS.only_use_en_dev), only_use_en_dev=FLAGS.only_use_en_dev),
"ax-g": "ax-g":
classifier_data_lib.AXgProcessor classifier_data_lib.AXgProcessor,
"superglue-rte":
classifier_data_lib.SuperGLUERTEProcessor
} }
task_name = FLAGS.classification_task_name.lower() task_name = FLAGS.classification_task_name.lower()
if task_name not in processors: if task_name not in processors:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment