Commit 5ad16f95 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 380296477
parent 19a49ae3
......@@ -135,18 +135,22 @@ class AxProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
train_mnli_dataset = tfds.load(
"glue/mnli", split="train", try_gcs=True).as_numpy_iterator()
return self._create_examples_tfds(train_mnli_dataset, "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
val_mnli_dataset = tfds.load(
"glue/mnli", split="validation_matched",
try_gcs=True).as_numpy_iterator()
return self._create_examples_tfds(val_mnli_dataset, "validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
test_ax_dataset = tfds.load(
"glue/ax", split="test", try_gcs=True).as_numpy_iterator()
return self._create_examples_tfds(test_ax_dataset, "test")
def get_labels(self):
"""See base class."""
......@@ -157,24 +161,20 @@ class AxProcessor(DataProcessor):
"""See base class."""
return "AX"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, dataset, set_type):
"""Creates examples for the training/dev/test sets."""
text_a_index = 1 if set_type == "test" else 8
text_b_index = 2 if set_type == "test" else 9
examples = []
for i, line in enumerate(lines):
# Skip header.
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[text_a_index])
text_b = self.process_text_fn(line[text_b_index])
if set_type == "test":
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
text_a = self.process_text_fn(example["hypothesis"])
text_b = self.process_text_fn(example["premise"])
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......@@ -264,34 +264,28 @@ class MnliProcessor(DataProcessor):
mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn)
self.dataset = tfds.load("glue/mnli", try_gcs=True)
if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
return self._create_examples_tfds("validation_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_mismatched")
return self._create_examples_tfds("validation_mismatched")
def get_test_examples(self, data_dir):
"""See base class."""
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
return self._create_examples_tfds("test_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
return self._create_examples_tfds("test_mismatched")
def get_labels(self):
"""See base class."""
......@@ -302,21 +296,22 @@ class MnliProcessor(DataProcessor):
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/mnli", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
text_a = self.process_text_fn(example["hypothesis"])
text_b = self.process_text_fn(example["premise"])
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......@@ -325,18 +320,15 @@ class MrpcProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -347,21 +339,22 @@ class MrpcProcessor(DataProcessor):
"""See base class."""
return "MRPC"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/mrpc", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test":
label = "0"
else:
label = self.process_text_fn(line[0])
text_a = self.process_text_fn(example["sentence1"])
text_b = self.process_text_fn(example["sentence2"])
if set_type != "test":
label = str(example["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......@@ -449,18 +442,15 @@ class QnliProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -471,23 +461,22 @@ class QnliProcessor(DataProcessor):
"""See base class."""
return "QNLI"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/qnli", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, 1)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
label = "entailment"
else:
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
text_a = self.process_text_fn(example["question"])
text_b = self.process_text_fn(example["sentence"])
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......@@ -496,18 +485,15 @@ class QqpProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -518,27 +504,22 @@ class QqpProcessor(DataProcessor):
"""See base class."""
return "QQP"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/qqp", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
if set_type == "test":
text_a = line[1]
text_b = line[2]
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
label = "0"
else:
# There appear to be some garbage lines in the train dataset.
try:
text_a = line[3]
text_b = line[4]
label = line[5]
except IndexError:
continue
text_a = self.process_text_fn(example["question1"])
text_b = self.process_text_fn(example["question2"])
if set_type != "test":
label = str(example["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......@@ -547,18 +528,15 @@ class RteProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -571,21 +549,22 @@ class RteProcessor(DataProcessor):
"""See base class."""
return "RTE"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test":
label = "entailment"
else:
label = tokenization.convert_to_unicode(line[3])
text_a = self.process_text_fn(example["sentence1"])
text_b = self.process_text_fn(example["sentence2"])
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......@@ -594,18 +573,15 @@ class SstProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -616,21 +592,20 @@ class SstProcessor(DataProcessor):
"""See base class."""
return "SST-2"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/sst2", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
text_a = self.process_text_fn(example["sentence"])
if set_type != "test":
label = str(example["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
return examples
......@@ -645,18 +620,33 @@ class StsBProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/stsb", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
label = 0.0
text_a = self.process_text_fn(example["sentence1"])
text_b = self.process_text_fn(example["sentence2"])
if set_type != "test":
label = self.label_type(example["label"])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
def get_labels(self):
"""See base class."""
......@@ -667,23 +657,6 @@ class StsBProcessor(DataProcessor):
"""See base class."""
return "STS-B"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[7])
text_b = tokenization.convert_to_unicode(line[8])
if set_type == "test":
label = 0.0
else:
label = self.label_type(tokenization.convert_to_unicode(line[9]))
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class TfdsProcessor(DataProcessor):
"""Processor for generic text classification and regression TFDS data set.
......@@ -818,18 +791,15 @@ class WnliProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples_tfds("test")
def get_labels(self):
"""See base class."""
......@@ -840,21 +810,22 @@ class WnliProcessor(DataProcessor):
"""See base class."""
return "WNLI"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"glue/wnli", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[3])
text_a = self.process_text_fn(example["sentence1"])
text_b = self.process_text_fn(example["sentence2"])
if set_type != "test":
label = str(example["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=None))
return examples
......
......@@ -173,6 +173,12 @@ flags.DEFINE_string(
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
if FLAGS.classification_task_name in [
"COLA", "WNLI", "SST-2", "MRPC", "QQP", "STS-B", "MNLI", "QNLI", "RTE",
"AX"
]:
assert not FLAGS.input_data_dir or FLAGS.tfds_params
else:
assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
FLAGS.tfds_params)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment