Commit 5ab76b51 authored by Saurabh Saxena's avatar Saurabh Saxena Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 308659564
parent 1784a826
...@@ -250,6 +250,51 @@ class MrpcProcessor(DataProcessor): ...@@ -250,6 +250,51 @@ class MrpcProcessor(DataProcessor):
return examples return examples
class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "QQP"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
try:
text_a = line[3]
text_b = line[4]
label = line[5]
except IndexError:
continue
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b,
label=label))
return examples
class ColaProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
......
...@@ -46,7 +46,7 @@ flags.DEFINE_string( ...@@ -46,7 +46,7 @@ flags.DEFINE_string(
"for the task.") "for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI", flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "SST-2", "XNLI"], ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI"],
"The name of the task to train BERT classifier.") "The name of the task to train BERT classifier.")
# BERT Squad task specific flags. # BERT Squad task specific flags.
...@@ -143,6 +143,7 @@ def generate_classifier_dataset(): ...@@ -143,6 +143,7 @@ def generate_classifier_dataset():
"mnli": classifier_data_lib.MnliProcessor, "mnli": classifier_data_lib.MnliProcessor,
"mrpc": classifier_data_lib.MrpcProcessor, "mrpc": classifier_data_lib.MrpcProcessor,
"qnli": classifier_data_lib.QnliProcessor, "qnli": classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor,
"sst-2": classifier_data_lib.SstProcessor, "sst-2": classifier_data_lib.SstProcessor,
"xnli": classifier_data_lib.XnliProcessor, "xnli": classifier_data_lib.XnliProcessor,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment