Commit 52515dc3 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 322197751
parent 57253ebc
......@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
for i, line in enumerate(lines):
# Only the test set has a header.
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
......@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def __init__(self,
mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn)
if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
......@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_mismatched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
def get_labels(self):
"""See base class."""
......@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor):
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
......@@ -244,9 +261,9 @@ class MrpcProcessor(DataProcessor):
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
......@@ -290,7 +307,7 @@ class PawsxProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
......@@ -307,7 +324,7 @@ class PawsxProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
......@@ -321,7 +338,7 @@ class PawsxProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
......@@ -368,9 +385,9 @@ class QnliProcessor(DataProcessor):
return "QNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, 1)
......@@ -415,9 +432,9 @@ class QqpProcessor(DataProcessor):
return "QQP"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
......@@ -462,7 +479,7 @@ class RteProcessor(DataProcessor):
return "RTE"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
......@@ -507,9 +524,9 @@ class SstProcessor(DataProcessor):
return "SST-2"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
......@@ -558,7 +575,7 @@ class StsBProcessor(DataProcessor):
return "STS-B"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
......@@ -671,7 +688,7 @@ class TfdsProcessor(DataProcessor):
return "TFDS_" + self.dataset_name
def _create_examples(self, split_name, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
if split_name not in self.dataset:
raise ValueError("Split {} not available.".format(split_name))
dataset = self.dataset[split_name].as_numpy_iterator()
......@@ -731,7 +748,7 @@ class WnliProcessor(DataProcessor):
return "WNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
......@@ -777,7 +794,7 @@ class XnliProcessor(DataProcessor):
"multinli.train.%s.tsv" % language))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -792,7 +809,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % i
......@@ -807,7 +824,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
......@@ -837,7 +854,7 @@ class XtremePawsxProcessor(DataProcessor):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -851,7 +868,7 @@ class XtremePawsxProcessor(DataProcessor):
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -865,7 +882,7 @@ class XtremePawsxProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -896,7 +913,7 @@ class XtremeXnliProcessor(DataProcessor):
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -909,7 +926,7 @@ class XtremeXnliProcessor(DataProcessor):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -923,7 +940,7 @@ class XtremeXnliProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = f"test-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -1052,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
for ex_index, example in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples))
......
......@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# XNLI task specific flag.
# MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task-specific flag.
flags.DEFINE_string(
"xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data "
"Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.")
# PAWS-X task specific flag.
# PAWS-X task-specific flag.
flags.DEFINE_string(
"pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.")
# Retrieva task specific flags
# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# Tagging task specific flags
# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task specific flags.
# BERT Squad task-specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
......@@ -179,7 +184,8 @@ def generate_classifier_dataset():
"cola":
classifier_data_lib.ColaProcessor,
"mnli":
classifier_data_lib.MnliProcessor,
functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment