Commit af924a4c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 381152422
parent a0494c94
......@@ -1287,20 +1287,17 @@ class SuperGLUEDataProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "val.jsonl")), "dev")
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")
return self._create_examples_tfds("test")
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
raise NotImplementedError()
......@@ -1317,17 +1314,18 @@ class BoolQProcessor(SuperGLUEDataProcessor):
"""See base class."""
return "BoolQ"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"super_glue/boolq", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for line in lines:
guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
text_a = self.process_text_fn(line["question"])
text_b = self.process_text_fn(line["passage"])
if set_type == "test":
label = "False"
else:
label = str(line["label"])
for example in dataset:
guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
text_a = self.process_text_fn(example["question"])
text_b = self.process_text_fn(example["passage"])
label = "False"
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -1345,17 +1343,18 @@ class CBProcessor(SuperGLUEDataProcessor):
"""See base class."""
return "CB"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
dataset = tfds.load(
"super_glue/cb", split=set_type, try_gcs=True).as_numpy_iterator()
examples = []
for line in lines:
guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
text_a = self.process_text_fn(line["premise"])
text_b = self.process_text_fn(line["hypothesis"])
if set_type == "test":
label = "entailment"
else:
label = self.process_text_fn(line["label"])
for example in dataset:
guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
text_a = self.process_text_fn(example["premise"])
text_b = self.process_text_fn(example["hypothesis"])
label = "entailment"
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -1375,17 +1374,18 @@ class SuperGLUERTEProcessor(SuperGLUEDataProcessor):
"""See base class."""
return "RTESuperGLUE"
def _create_examples(self, lines, set_type):
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line["premise"])
text_b = self.process_text_fn(line["hypothesis"])
if set_type == "test":
label = "entailment"
else:
label = self.process_text_fn(line["label"])
dataset = tfds.load(
"super_glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
for example in dataset:
guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
text_a = self.process_text_fn(example["premise"])
text_b = self.process_text_fn(example["hypothesis"])
label = "entailment"
if set_type != "test":
label = self.get_labels()[example["label"]]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""
import os
import tempfile
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_datasets as tfds
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
def decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
return tf.io.parse_single_example(record, name_to_features)
class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(BertClassifierLibTest, self).setUp()
self.model_dir = self.get_temp_dir()
self.processors = {
"CB": classifier_data_lib.CBProcessor,
"SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
"BOOLQ": classifier_data_lib.BoolQProcessor,
}
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens
]).encode("utf-8"))
vocab_file = vocab_writer.name
self.tokenizer = tokenization.FullTokenizer(vocab_file)
@parameterized.parameters(
{"task_type": "CB"},
{"task_type": "BOOLQ"},
{"task_type": "SUPERGLUE-RTE"},
)
def test_generate_dataset_from_tfds_processor(self, task_type):
with tfds.testing.mock_data(num_examples=5):
output_path = os.path.join(self.model_dir, task_type)
processor = self.processors[task_type]()
classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
self.tokenizer,
train_data_output_path=output_path,
eval_data_output_path=output_path,
test_data_output_path=output_path)
files = tf.io.gfile.glob(output_path)
self.assertNotEmpty(files)
train_dataset = tf.data.TFRecordDataset(output_path)
seq_length = 128
label_type = tf.int64
name_to_features = {
"input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.io.FixedLenFeature([], label_type),
}
train_dataset = train_dataset.map(
lambda record: decode_record(record, name_to_features))
# If data is retrieved without error, then all requirements
# including data type/shapes are met.
_ = next(iter(train_dataset))
if __name__ == "__main__":
tf.test.main()
......@@ -175,7 +175,7 @@ def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
if FLAGS.classification_task_name in [
"COLA", "WNLI", "SST-2", "MRPC", "QQP", "STS-B", "MNLI", "QNLI", "RTE",
"AX"
"AX", "SUPERGLUE-RTE", "CB", "BoolQ"
]:
assert not FLAGS.input_data_dir or FLAGS.tfds_params
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment