Commit 636ca66f authored by Vincent Etter's avatar Vincent Etter Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 330596737
parent 095bc035
...@@ -124,6 +124,54 @@ class DataProcessor(object): ...@@ -124,6 +124,54 @@ class DataProcessor(object):
return lines return lines
class AxProcessor(DataProcessor):
"""Processor for the AX dataset (GLUE diagnostics dataset)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "AX"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
text_a_index = 1 if set_type == "test" else 8
text_b_index = 2 if set_type == "test" else 9
examples = []
for i, line in enumerate(lines):
# Skip header.
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[text_a_index])
text_b = self.process_text_fn(line[text_b_index])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
......
...@@ -51,7 +51,7 @@ flags.DEFINE_string( ...@@ -51,7 +51,7 @@ flags.DEFINE_string(
"for the task.") "for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI", flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", ["AX", "COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"XTREME-PAWS-X"], "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The " "The name of the task to train BERT classifier. The "
...@@ -182,6 +182,8 @@ def generate_classifier_dataset(): ...@@ -182,6 +182,8 @@ def generate_classifier_dataset():
max_seq_length=FLAGS.max_seq_length) max_seq_length=FLAGS.max_seq_length)
else: else:
processors = { processors = {
"ax":
classifier_data_lib.AxProcessor,
"cola": "cola":
classifier_data_lib.ColaProcessor, classifier_data_lib.ColaProcessor,
"mnli": "mnli":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment