Commit d697041a authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 301915584
parent 19d930c3
@@ -24,6 +24,7 @@ import os
 from absl import logging
 import tensorflow as tf
+import tensorflow_datasets as tfds
 from official.nlp.bert import tokenization

@@ -386,6 +387,99 @@ class QnliProcessor(DataProcessor):
     return examples
+
+
+class TfdsProcessor(DataProcessor):
+  """Processor for generic text classification TFDS data set.
+
+  The TFDS parameters are expected to be provided in the tfds_params string, in
+  a comma-separated list of parameter assignments.
+
+  Examples:
+    tfds_params="dataset=scicite,text_key=string"
+    tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
+    tfds_params="dataset=glue/cola,text_key=sentence"
+    tfds_params="dataset=glue/sst2,text_key=sentence"
+    tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
+    tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
+
+  Possible parameters (please refer to the documentation of Tensorflow Datasets
+  (TFDS) for the meaning of individual parameters):
+    dataset: Required dataset name (potentially with subset and version number).
+    data_dir: Optional TFDS source root directory.
+    train_split: Name of the train split (defaults to `train`).
+    dev_split: Name of the dev split (defaults to `validation`).
+    test_split: Name of the test split (defaults to `test`).
+    text_key: Key of the text_a feature (defaults to `text`).
+    text_b_key: Key of the second text feature if available.
+    label_key: Key of the label feature (defaults to `label`).
+    test_text_key: Key of the text feature to use in test set.
+    test_text_b_key: Key of the second text feature to use in test set.
+    test_label: String to be used as the label for all test examples.
+  """
+
+  def __init__(self, tfds_params,
+               process_text_fn=tokenization.convert_to_unicode):
+    super(TfdsProcessor, self).__init__(process_text_fn)
+    self._process_tfds_params_str(tfds_params)
+    self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
+                                   with_info=True)
+    self._labels = list(range(info.features[self.label_key].num_classes))
+
+  def _process_tfds_params_str(self, params_str):
+    """Extracts TFDS parameters from a comma-separated assignments string."""
+    tuples = [x.split("=") for x in params_str.split(",")]
+    d = {k.strip(): v.strip() for k, v in tuples}
+    self.dataset_name = d["dataset"]  # Required.
+    self.data_dir = d.get("data_dir", None)
+    self.train_split = d.get("train_split", "train")
+    self.dev_split = d.get("dev_split", "validation")
+    self.test_split = d.get("test_split", "test")
+    self.text_key = d.get("text_key", "text")
+    self.text_b_key = d.get("text_b_key", None)
+    self.label_key = d.get("label_key", "label")
+    self.test_text_key = d.get("test_text_key", self.text_key)
+    self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
+    self.test_label = d.get("test_label", "test_example")
+
+  def get_train_examples(self, data_dir):
+    assert data_dir is None
+    return self._create_examples(self.train_split, "train")
+
+  def get_dev_examples(self, data_dir):
+    assert data_dir is None
+    return self._create_examples(self.dev_split, "dev")
+
+  def get_test_examples(self, data_dir):
+    assert data_dir is None
+    return self._create_examples(self.test_split, "test")
+
+  def get_labels(self):
+    return self._labels
+
+  def get_processor_name(self):
+    return "TFDS_" + self.dataset_name
+
+  def _create_examples(self, split_name, set_type):
+    """Creates examples for the training/dev/test sets."""
+    if split_name not in self.dataset:
+      raise ValueError("Split {} not available.".format(split_name))
+    dataset = self.dataset[split_name].as_numpy_iterator()
+    examples = []
+    text_b = None
+    for i, example in enumerate(dataset):
+      guid = "%s-%s" % (set_type, i)
+      if set_type == "test":
+        text_a = self.process_text_fn(example[self.test_text_key])
+        if self.test_text_b_key:
+          text_b = self.process_text_fn(example[self.test_text_b_key])
+        label = self.test_label
+      else:
+        text_a = self.process_text_fn(example[self.text_key])
+        if self.text_b_key:
+          text_b = self.process_text_fn(example[self.text_b_key])
+        label = int(example[self.label_key])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
 
 
 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
...
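Not part of this commit: a minimal usage sketch for the TfdsProcessor added above. It assumes the class is importable as official.nlp.bert.classifier_data_lib (classifier_data_lib is the module name used by the flag-handling changes in the second file below); the glue/mrpc parameter string is copied from the class docstring, and tfds.load will download the dataset on first use.

# Sketch only: drive the new processor directly, outside the data-generation tool.
from official.nlp.bert import classifier_data_lib  # assumed import path

processor = classifier_data_lib.TfdsProcessor(
    tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2")
print(processor.get_processor_name())  # "TFDS_glue/mrpc"
print(processor.get_labels())          # label ids from TFDS metadata, e.g. [0, 1]
train_examples = processor.get_train_examples(None)  # data_dir must be None here
print(train_examples[0].guid, train_examples[0].label)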
@@ -104,22 +104,16 @@ flags.DEFINE_enum(
     "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
     "while ALBERT uses sentence_piece tokenizer.")
 
+flags.DEFINE_string("tfds_params", "",
+                    "Comma-separated list of TFDS parameter assignments for "
+                    "generic classification data import (for more details "
+                    "see the TfdsProcessor class documentation).")
 
 
 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
-  assert FLAGS.input_data_dir and FLAGS.classification_task_name
+  assert (FLAGS.input_data_dir and FLAGS.classification_task_name
+          or FLAGS.tfds_params)
-
-  processors = {
-      "cola": classifier_data_lib.ColaProcessor,
-      "mnli": classifier_data_lib.MnliProcessor,
-      "mrpc": classifier_data_lib.MrpcProcessor,
-      "qnli": classifier_data_lib.QnliProcessor,
-      "sst-2": classifier_data_lib.SstProcessor,
-      "xnli": classifier_data_lib.XnliProcessor,
-  }
-  task_name = FLAGS.classification_task_name.lower()
-  if task_name not in processors:
-    raise ValueError("Task not found: %s" % (task_name))
 
   if FLAGS.tokenizer_impl == "word_piece":
     tokenizer = tokenization.FullTokenizer(

@@ -131,14 +125,38 @@ def generate_classifier_dataset():
   processor_text_fn = functools.partial(
       tokenization.preprocess_text, lower=FLAGS.do_lower_case)
 
-  processor = processors[task_name](processor_text_fn)
-  return classifier_data_lib.generate_tf_record_from_data_file(
-      processor,
-      FLAGS.input_data_dir,
-      tokenizer,
-      train_data_output_path=FLAGS.train_data_output_path,
-      eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length)
+  if FLAGS.tfds_params:
+    processor = classifier_data_lib.TfdsProcessor(
+        tfds_params=FLAGS.tfds_params,
+        process_text_fn=processor_text_fn)
+    return classifier_data_lib.generate_tf_record_from_data_file(
+        processor,
+        None,
+        tokenizer,
+        train_data_output_path=FLAGS.train_data_output_path,
+        eval_data_output_path=FLAGS.eval_data_output_path,
+        max_seq_length=FLAGS.max_seq_length)
+  else:
+    processors = {
+        "cola": classifier_data_lib.ColaProcessor,
+        "mnli": classifier_data_lib.MnliProcessor,
+        "mrpc": classifier_data_lib.MrpcProcessor,
+        "qnli": classifier_data_lib.QnliProcessor,
+        "sst-2": classifier_data_lib.SstProcessor,
+        "xnli": classifier_data_lib.XnliProcessor,
+    }
+    task_name = FLAGS.classification_task_name.lower()
+    if task_name not in processors:
+      raise ValueError("Task not found: %s" % (task_name))
+    processor = processors[task_name](processor_text_fn)
+    return classifier_data_lib.generate_tf_record_from_data_file(
+        processor,
+        FLAGS.input_data_dir,
+        tokenizer,
+        train_data_output_path=FLAGS.train_data_output_path,
+        eval_data_output_path=FLAGS.eval_data_output_path,
+        max_seq_length=FLAGS.max_seq_length)
 
 
 def generate_squad_dataset():
...
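Also not part of the commit: a standalone sketch of how a tfds_params value decomposes under the parsing added in TfdsProcessor._process_tfds_params_str, using the imdb_reviews string from the class docstring. The variable names below are illustrative only; no data is loaded.

# Sketch only: mirrors the comma-separated assignment parsing so the effect of
# a tfds_params string can be checked without calling tfds.load.
params_str = "dataset=imdb_reviews,test_split=,dev_split=test"
d = {k.strip(): v.strip() for k, v in (x.split("=") for x in params_str.split(","))}

dataset_name = d["dataset"]                  # "imdb_reviews" (required key)
train_split = d.get("train_split", "train")  # default "train" applies
dev_split = d.get("dev_split", "validation") # overridden to "test" here
test_split = d.get("test_split", "test")     # overridden to "" (no test split)
print(dataset_name, train_split, dev_split, test_split)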