Commit 487af3fa authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 379108917
parent e1df7597
...@@ -181,20 +181,21 @@ class AxProcessor(DataProcessor): ...@@ -181,20 +181,21 @@ class AxProcessor(DataProcessor):
class ColaProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
super(ColaProcessor, self).__init__(process_text_fn)
self.dataset = tfds.load("glue/cola", try_gcs=True)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples_tfds("train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples_tfds("validation")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples_tfds("test")
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -205,22 +206,19 @@ class ColaProcessor(DataProcessor): ...@@ -205,22 +206,19 @@ class ColaProcessor(DataProcessor):
"""See base class.""" """See base class."""
return "COLA" return "COLA"
def _create_examples(self, lines, set_type): def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = self.dataset[set_type].as_numpy_iterator()
examples = [] examples = []
for i, line in enumerate(lines): for i, example in enumerate(dataset):
# Only the test set has a header.
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
if set_type == "test": label = "0"
text_a = self.process_text_fn(line[1]) text_a = self.process_text_fn(example["sentence"])
label = "0" if set_type != "test":
else: label = str(example["label"])
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) InputExample(
guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
return examples return examples
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment