Commit fda53f78 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 375854504
parent 50905fd2
......@@ -1316,6 +1316,92 @@ class AXgProcessor(DataProcessor):
return examples
class BoolQProcessor(DataProcessor):
"""Processor for the BoolQ dataset (SuperGLUE diagnostics dataset)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "val.jsonl")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")
def get_labels(self):
"""See base class."""
return ["True", "False"]
@staticmethod
def get_processor_name():
"""See base class."""
return "BoolQ"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for line in lines:
guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
text_a = self.process_text_fn(line["question"])
text_b = self.process_text_fn(line["passage"])
if set_type == "test":
label = "False"
else:
label = str(line["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class CBProcessor(DataProcessor):
"""Processor for the CB dataset (SuperGLUE diagnostics dataset)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "val.jsonl")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "neutral", "contradiction"]
@staticmethod
def get_processor_name():
"""See base class."""
return "CB"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for line in lines:
guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
text_a = self.process_text_fn(line["premise"])
text_b = self.process_text_fn(line["hypothesis"])
if set_type == "test":
label = "entailment"
else:
label = self.process_text_fn(line["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class SuperGLUERTEProcessor(DataProcessor):
"""Processor for the RTE dataset (SuperGLUE version)."""
......
......@@ -50,7 +50,7 @@ flags.DEFINE_enum(
"classification_task_name", "MNLI", [
"AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
"AX-g", "SUPERGLUE-RTE"
"AX-g", "SUPERGLUE-RTE", "CB", "BoolQ"
], "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
......@@ -243,7 +243,11 @@ def generate_classifier_dataset():
"ax-g":
classifier_data_lib.AXgProcessor,
"superglue-rte":
classifier_data_lib.SuperGLUERTEProcessor
classifier_data_lib.SuperGLUERTEProcessor,
"cb":
classifier_data_lib.CBProcessor,
"boolq":
classifier_data_lib.BoolQProcessor,
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
......
......@@ -65,6 +65,8 @@ EVAL_METRIC_MAP = {
AXG_CLASS_NAMES = ['entailment', 'not_entailment']
RTE_CLASS_NAMES = ['entailment', 'not_entailment']
CB_CLASS_NAMES = ['entailment', 'neutral', 'contradiction']
BOOLQ_CLASS_NAMES = ['True', 'False']
def _override_exp_config_by_file(exp_config, exp_config_files):
......@@ -154,7 +156,9 @@ def _write_submission_file(task, seq_length):
write_fn = binary_helper.write_superglue_classification
write_fn_map = {
'RTE': functools.partial(write_fn, class_names=RTE_CLASS_NAMES),
'AX-g': functools.partial(write_fn, class_names=AXG_CLASS_NAMES)
'AX-g': functools.partial(write_fn, class_names=AXG_CLASS_NAMES),
'CB': functools.partial(write_fn, class_names=CB_CLASS_NAMES),
'BoolQ': functools.partial(write_fn, class_names=BOOLQ_CLASS_NAMES)
}
logging.info('Predicting %s', FLAGS.test_input_path)
write_fn_map[FLAGS.task_name](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment