Commit 5ad16f95 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 380296477
parent 19a49ae3
...@@ -135,18 +135,22 @@ class AxProcessor(DataProcessor): ...@@ -135,18 +135,22 @@ class AxProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class.

  `data_dir` is unused; AX has no training split, so training examples are
  loaded from the TFDS "glue/mnli" train split (kept for the base-class API).
  """
  train_mnli_dataset = tfds.load(
      "glue/mnli", split="train", try_gcs=True).as_numpy_iterator()
  return self._create_examples_tfds(train_mnli_dataset, "train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  val_mnli_dataset = tfds.load(
      "glue/mnli", split="validation_matched",
      try_gcs=True).as_numpy_iterator()
  return self._create_examples_tfds(val_mnli_dataset, "validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  test_ax_dataset = tfds.load(
      "glue/ax", split="test", try_gcs=True).as_numpy_iterator()
  return self._create_examples_tfds(test_ax_dataset, "test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -157,24 +161,20 @@ class AxProcessor(DataProcessor): ...@@ -157,24 +161,20 @@ class AxProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "AX" return "AX"
def _create_examples_tfds(self, dataset, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    dataset: numpy iterator over TFDS "glue/mnli" (train/validation) or
      "glue/ax" (test) examples; each is a dict with "hypothesis",
      "premise" and, for labeled splits, an integer "label".
    set_type: split tag used in the example guid ("train", "validation",
      "test").

  Returns:
    A list of `InputExample`s.
  """
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # The test split carries no gold labels (TFDS encodes them as -1);
    # use a fixed placeholder so downstream label mapping still works.
    label = "contradiction"
    text_a = self.process_text_fn(example["hypothesis"])
    text_b = self.process_text_fn(example["premise"])
    if set_type != "test":
      label = self.get_labels()[example["label"]]
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
...@@ -264,34 +264,28 @@ class MnliProcessor(DataProcessor): ...@@ -264,34 +264,28 @@ class MnliProcessor(DataProcessor):
mnli_type="matched", mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode): process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn) super(MnliProcessor, self).__init__(process_text_fn)
self.dataset = tfds.load("glue/mnli", try_gcs=True)
if mnli_type not in ("matched", "mismatched"): if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type) raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type self.mnli_type = mnli_type
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class.

  Returns the matched or mismatched validation split depending on
  `self.mnli_type` (validated in `__init__`).
  """
  if self.mnli_type == "matched":
    return self._create_examples_tfds("validation_matched")
  else:
    return self._create_examples_tfds("validation_mismatched")

def get_test_examples(self, data_dir):
  """See base class.

  Returns the matched or mismatched test split depending on
  `self.mnli_type`.
  """
  if self.mnli_type == "matched":
    return self._create_examples_tfds("test_matched")
  else:
    return self._create_examples_tfds("test_mismatched")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -302,21 +296,22 @@ class MnliProcessor(DataProcessor): ...@@ -302,21 +296,22 @@ class MnliProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "MNLI" return "MNLI"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/mnli" split name ("train",
      "validation_matched", "validation_mismatched", "test_matched",
      "test_mismatched"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s.
  """
  dataset = tfds.load(
      "glue/mnli", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    label = "contradiction"
    text_a = self.process_text_fn(example["hypothesis"])
    text_b = self.process_text_fn(example["premise"])
    # MNLI test splits are named "test_matched"/"test_mismatched", so a
    # plain `set_type != "test"` comparison would wrongly treat them as
    # labeled and index get_labels() with TFDS's -1 placeholder label.
    if not set_type.startswith("test"):
      label = self.get_labels()[example["label"]]
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
...@@ -325,18 +320,15 @@ class MrpcProcessor(DataProcessor): ...@@ -325,18 +320,15 @@ class MrpcProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -347,21 +339,22 @@ class MrpcProcessor(DataProcessor): ...@@ -347,21 +339,22 @@ class MrpcProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "MRPC" return "MRPC"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/mrpc" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s.
  """
  dataset = tfds.load(
      "glue/mrpc", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # Test split has no gold labels; keep the "0" placeholder.
    label = "0"
    text_a = self.process_text_fn(example["sentence1"])
    text_b = self.process_text_fn(example["sentence2"])
    if set_type != "test":
      label = str(example["label"])
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
...@@ -449,18 +442,15 @@ class QnliProcessor(DataProcessor): ...@@ -449,18 +442,15 @@ class QnliProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -471,23 +461,22 @@ class QnliProcessor(DataProcessor): ...@@ -471,23 +461,22 @@ class QnliProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "QNLI" return "QNLI"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/qnli" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s.
  """
  dataset = tfds.load(
      "glue/qnli", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # Test split has no gold labels; use a fixed placeholder.
    label = "entailment"
    text_a = self.process_text_fn(example["question"])
    text_b = self.process_text_fn(example["sentence"])
    if set_type != "test":
      label = self.get_labels()[example["label"]]
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
...@@ -496,18 +485,15 @@ class QqpProcessor(DataProcessor): ...@@ -496,18 +485,15 @@ class QqpProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -518,27 +504,22 @@ class QqpProcessor(DataProcessor): ...@@ -518,27 +504,22 @@ class QqpProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "QQP" return "QQP"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/qqp" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s.
  """
  dataset = tfds.load(
      "glue/qqp", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # Test split has no gold labels; keep the "0" placeholder.
    label = "0"
    text_a = self.process_text_fn(example["question1"])
    text_b = self.process_text_fn(example["question2"])
    if set_type != "test":
      label = str(example["label"])
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
...@@ -547,18 +528,15 @@ class RteProcessor(DataProcessor): ...@@ -547,18 +528,15 @@ class RteProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -571,21 +549,22 @@ class RteProcessor(DataProcessor): ...@@ -571,21 +549,22 @@ class RteProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "RTE" return "RTE"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/rte" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s.
  """
  dataset = tfds.load(
      "glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # Test split has no gold labels; use a fixed placeholder.
    label = "entailment"
    text_a = self.process_text_fn(example["sentence1"])
    text_b = self.process_text_fn(example["sentence2"])
    if set_type != "test":
      label = self.get_labels()[example["label"]]
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
...@@ -594,18 +573,15 @@ class SstProcessor(DataProcessor): ...@@ -594,18 +573,15 @@ class SstProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -616,21 +592,20 @@ class SstProcessor(DataProcessor): ...@@ -616,21 +592,20 @@ class SstProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "SST-2" return "SST-2"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/sst2" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of single-sentence `InputExample`s (text_b is None).
  """
  dataset = tfds.load(
      "glue/sst2", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # Test split has no gold labels; keep the "0" placeholder.
    label = "0"
    text_a = self.process_text_fn(example["sentence"])
    if set_type != "test":
      label = str(example["label"])
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
  return examples
...@@ -645,18 +620,33 @@ class StsBProcessor(DataProcessor): ...@@ -645,18 +620,33 @@ class StsBProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")

def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/stsb" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s with regression labels of type
    `self.label_type`.
  """
  dataset = tfds.load(
      "glue/stsb", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # STS-B is a regression task; test split gets a 0.0 placeholder score.
    label = 0.0
    text_a = self.process_text_fn(example["sentence1"])
    text_b = self.process_text_fn(example["sentence2"])
    if set_type != "test":
      label = self.label_type(example["label"])
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -667,23 +657,6 @@ class StsBProcessor(DataProcessor): ...@@ -667,23 +657,6 @@ class StsBProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "STS-B" return "STS-B"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[7])
text_b = tokenization.convert_to_unicode(line[8])
if set_type == "test":
label = 0.0
else:
label = self.label_type(tokenization.convert_to_unicode(line[9]))
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class TfdsProcessor(DataProcessor): class TfdsProcessor(DataProcessor):
"""Processor for generic text classification and regression TFDS data set. """Processor for generic text classification and regression TFDS data set.
...@@ -818,18 +791,15 @@ class WnliProcessor(DataProcessor): ...@@ -818,18 +791,15 @@ class WnliProcessor(DataProcessor):
def get_train_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("train")

def get_dev_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("validation")

def get_test_examples(self, data_dir):
  """See base class. `data_dir` is unused; data comes from TFDS."""
  return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -840,21 +810,22 @@ class WnliProcessor(DataProcessor): ...@@ -840,21 +810,22 @@ class WnliProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "WNLI" return "WNLI"
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: a TFDS "glue/wnli" split name ("train", "validation",
      "test"), also used as the guid prefix.

  Returns:
    A list of `InputExample`s.
  """
  dataset = tfds.load(
      "glue/wnli", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for i, example in enumerate(dataset):
    guid = "%s-%s" % (set_type, i)
    # Test split has no gold labels; keep the "0" placeholder.
    label = "0"
    text_a = self.process_text_fn(example["sentence1"])
    text_b = self.process_text_fn(example["sentence2"])
    if set_type != "test":
      label = str(example["label"])
    examples.append(
        InputExample(
            guid=guid, text_a=text_a, text_b=text_b, label=label,
            weight=None))
  return examples
......
...@@ -173,8 +173,14 @@ flags.DEFINE_string( ...@@ -173,8 +173,14 @@ flags.DEFINE_string(
def generate_classifier_dataset(): def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data.""" """Generates classifier dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.classification_task_name or if FLAGS.classification_task_name in [
FLAGS.tfds_params) "COLA", "WNLI", "SST-2", "MRPC", "QQP", "STS-B", "MNLI", "QNLI", "RTE",
"AX"
]:
assert not FLAGS.input_data_dir or FLAGS.tfds_params
else:
assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
FLAGS.tfds_params)
if FLAGS.tokenization == "WordPiece": if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer( tokenizer = tokenization.FullTokenizer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment