Commit 52515dc3 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 322197751
parent 57253ebc
@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
     return "COLA"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
-      # Only the test set has a header
+    for i, line in enumerate(lines):
+      # Only the test set has a header.
       if set_type == "test" and i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
 class MnliProcessor(DataProcessor):
   """Processor for the MultiNLI data set (GLUE version)."""
 
+  def __init__(self,
+               mnli_type="matched",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(MnliProcessor, self).__init__(process_text_fn)
+    if mnli_type not in ("matched", "mismatched"):
+      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+    self.mnli_type = mnli_type
+
   def get_train_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
 
   def get_dev_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+          "dev_matched")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
+          "dev_mismatched")
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor):
     return "MNLI"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
@@ -244,9 +261,9 @@ class MrpcProcessor(DataProcessor):
     return "MRPC"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
@@ -290,7 +307,7 @@ class PawsxProcessor(DataProcessor):
           self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
 
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[1])
       text_b = self.process_text_fn(line[2])
@@ -307,7 +324,7 @@ class PawsxProcessor(DataProcessor):
           self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
 
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "dev-%d" % i
       text_a = self.process_text_fn(line[1])
       text_b = self.process_text_fn(line[2])
@@ -321,7 +338,7 @@ class PawsxProcessor(DataProcessor):
     examples_by_lang = {k: [] for k in self.supported_languages}
     for lang in self.supported_languages:
       lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
-      for (i, line) in enumerate(lines):
+      for i, line in enumerate(lines):
         guid = "test-%d" % i
         text_a = self.process_text_fn(line[1])
         text_b = self.process_text_fn(line[2])
@@ -368,9 +385,9 @@ class QnliProcessor(DataProcessor):
     return "QNLI"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, 1)
@@ -415,9 +432,9 @@ class QqpProcessor(DataProcessor):
     return "QQP"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, line[0])
@@ -462,7 +479,7 @@ class RteProcessor(DataProcessor):
     return "RTE"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
     for i, line in enumerate(lines):
       if i == 0:
@@ -507,9 +524,9 @@ class SstProcessor(DataProcessor):
     return "SST-2"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
@@ -558,7 +575,7 @@ class StsBProcessor(DataProcessor):
     return "STS-B"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
     for i, line in enumerate(lines):
       if i == 0:
@@ -671,7 +688,7 @@ class TfdsProcessor(DataProcessor):
     return "TFDS_" + self.dataset_name
 
   def _create_examples(self, split_name, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     if split_name not in self.dataset:
       raise ValueError("Split {} not available.".format(split_name))
     dataset = self.dataset[split_name].as_numpy_iterator()
@@ -731,7 +748,7 @@ class WnliProcessor(DataProcessor):
     return "WNLI"
 
   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
     for i, line in enumerate(lines):
       if i == 0:
@@ -777,7 +794,7 @@ class XnliProcessor(DataProcessor):
                            "multinli.train.%s.tsv" % language))[1:])
 
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
@@ -792,7 +809,7 @@ class XnliProcessor(DataProcessor):
     """See base class."""
     lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "dev-%d" % i
@@ -807,7 +824,7 @@ class XnliProcessor(DataProcessor):
     """See base class."""
     lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
     examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "test-%d" % i
@@ -837,7 +854,7 @@ class XtremePawsxProcessor(DataProcessor):
     """See base class."""
     lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
@@ -851,7 +868,7 @@ class XtremePawsxProcessor(DataProcessor):
     lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "dev-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
@@ -865,7 +882,7 @@ class XtremePawsxProcessor(DataProcessor):
     examples_by_lang = {k: [] for k in self.supported_languages}
     for lang in self.supported_languages:
       lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
+      for i, line in enumerate(lines):
         guid = "test-%d" % i
         text_a = self.process_text_fn(line[0])
         text_b = self.process_text_fn(line[1])
@@ -896,7 +913,7 @@ class XtremeXnliProcessor(DataProcessor):
     lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
@@ -909,7 +926,7 @@ class XtremeXnliProcessor(DataProcessor):
     """See base class."""
     lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "dev-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
@@ -923,7 +940,7 @@ class XtremeXnliProcessor(DataProcessor):
     examples_by_lang = {k: [] for k in self.supported_languages}
     for lang in self.supported_languages:
       lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
+      for i, line in enumerate(lines):
         guid = f"test-{i}"
         text_a = self.process_text_fn(line[0])
         text_b = self.process_text_fn(line[1])
@@ -1052,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
   tf.io.gfile.makedirs(os.path.dirname(output_file))
   writer = tf.io.TFRecordWriter(output_file)
 
-  for (ex_index, example) in enumerate(examples):
+  for ex_index, example in enumerate(examples):
     if ex_index % 10000 == 0:
       logging.info("Writing example %d of %d", ex_index, len(examples))

...
@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
                   "only and for XNLI is all languages combined. Same for "
                   "PAWS-X.")
 
-# XNLI task specific flag.
+# MNLI task-specific flag.
+flags.DEFINE_enum(
+    "mnli_type", "matched", ["matched", "mismatched"],
+    "The type of MNLI dataset.")
+
+# XNLI task-specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
-    "Language of training data for XNIL task. If the value is 'all', the data "
+    "Language of training data for XNLI task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
-# PAWS-X task specific flag.
+# PAWS-X task-specific flag.
 flags.DEFINE_string(
     "pawsx_language", "en",
-    "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+    "Language of training data for PAWS-X task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
-# Retrieva task specific flags
+# Retrieval task-specific flags.
 flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                   "The name of sentence retrieval task for scoring")
 
-# Tagging task specific flags
+# Tagging task-specific flags.
 flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                   "The name of BERT tagging (token classification) task.")
 
-# BERT Squad task specific flags.
+# BERT Squad task-specific flags.
 flags.DEFINE_string(
     "squad_data_file", None,
     "The input data file in for generating training data for BERT squad task.")
@@ -179,7 +184,8 @@ def generate_classifier_dataset():
       "cola":
           classifier_data_lib.ColaProcessor,
       "mnli":
-          classifier_data_lib.MnliProcessor,
+          functools.partial(classifier_data_lib.MnliProcessor,
+                            mnli_type=FLAGS.mnli_type),
       "mrpc":
           classifier_data_lib.MrpcProcessor,
       "qnli":

...
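
The `functools.partial` entry above keeps the processor map uniform: every value remains a callable that constructs a processor, whether it is a bare class or a constructor with the MNLI flag pre-bound. A small illustration with hypothetical names:

import functools


class DemoProcessor:
  """Stand-in for a classifier_data_lib processor."""

  def __init__(self, mnli_type="matched"):
    self.mnli_type = mnli_type


processors = {
    # Constructor with the keyword pre-bound, as in the diff.
    "mnli": functools.partial(DemoProcessor, mnli_type="mismatched"),
    # Bare class; instantiated exactly the same way below.
    "mrpc": DemoProcessor,
}

for name, factory in processors.items():
  print(name, factory().mnli_type)  # mnli mismatched / mrpc matched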