Commit c73c0127 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 381367878
parent 88b2a354
...@@ -129,6 +129,30 @@ class DataProcessor(object): ...@@ -129,6 +129,30 @@ class DataProcessor(object):
lines.append(json.loads(json_str)) lines.append(json.loads(json_str))
return lines return lines
def featurize_example(self, *args, **kwargs):
  """Builds one `InputFeatures` from one `InputExample`.

  Every positional and keyword argument is passed through unchanged to
  `convert_single_example`, and its result is returned as-is.
  """
  features = convert_single_example(*args, **kwargs)
  return features
class DefaultGLUEDataProcessor(DataProcessor):
  """Shared base for GLUE and SuperGLUE task processors.

  Routes the train/dev/test getters to the "train", "validation" and "test"
  splits respectively and delegates the actual example construction to
  `_create_examples_tfds`, which every subclass must override.
  """

  def get_train_examples(self, data_dir):
    """See base class. `data_dir` is not used; the "train" split is read."""
    return self._create_examples_tfds("train")

  def get_dev_examples(self, data_dir):
    """See base class. `data_dir` is not used; "validation" is read."""
    return self._create_examples_tfds("validation")

  def get_test_examples(self, data_dir):
    """See base class. `data_dir` is not used; the "test" split is read."""
    return self._create_examples_tfds("test")

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets.

    Args:
      set_type: Split name. The getters of this class pass "train",
        "validation" or "test".

    Raises:
      NotImplementedError: Always. Subclasses must supply the implementation.
    """
    # Name the concrete class so a missing override is easy to locate.
    raise NotImplementedError(
        "{} must implement _create_examples_tfds().".format(
            type(self).__name__))
class AxProcessor(DataProcessor): class AxProcessor(DataProcessor):
"""Processor for the AX dataset (GLUE diagnostics dataset).""" """Processor for the AX dataset (GLUE diagnostics dataset)."""
...@@ -178,21 +202,9 @@ class AxProcessor(DataProcessor): ...@@ -178,21 +202,9 @@ class AxProcessor(DataProcessor):
return examples return examples
class ColaProcessor(DataProcessor): class ColaProcessor(DefaultGLUEDataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["0", "1"]
...@@ -315,21 +327,9 @@ class MnliProcessor(DataProcessor): ...@@ -315,21 +327,9 @@ class MnliProcessor(DataProcessor):
return examples return examples
class MrpcProcessor(DataProcessor): class MrpcProcessor(DefaultGLUEDataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["0", "1"]
...@@ -437,21 +437,9 @@ class PawsxProcessor(DataProcessor): ...@@ -437,21 +437,9 @@ class PawsxProcessor(DataProcessor):
return "XTREME-PAWS-X" return "XTREME-PAWS-X"
class QnliProcessor(DataProcessor): class QnliProcessor(DefaultGLUEDataProcessor):
"""Processor for the QNLI data set (GLUE version).""" """Processor for the QNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["entailment", "not_entailment"] return ["entailment", "not_entailment"]
...@@ -480,21 +468,9 @@ class QnliProcessor(DataProcessor): ...@@ -480,21 +468,9 @@ class QnliProcessor(DataProcessor):
return examples return examples
class QqpProcessor(DataProcessor): class QqpProcessor(DefaultGLUEDataProcessor):
"""Processor for the QQP data set (GLUE version).""" """Processor for the QQP data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["0", "1"]
...@@ -523,21 +499,9 @@ class QqpProcessor(DataProcessor): ...@@ -523,21 +499,9 @@ class QqpProcessor(DataProcessor):
return examples return examples
class RteProcessor(DataProcessor): class RteProcessor(DefaultGLUEDataProcessor):
"""Processor for the RTE data set (GLUE version).""" """Processor for the RTE data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
# All datasets are converted to 2-class split, where for 3-class datasets we # All datasets are converted to 2-class split, where for 3-class datasets we
...@@ -568,21 +532,9 @@ class RteProcessor(DataProcessor): ...@@ -568,21 +532,9 @@ class RteProcessor(DataProcessor):
return examples return examples
class SstProcessor(DataProcessor): class SstProcessor(DefaultGLUEDataProcessor):
"""Processor for the SST-2 data set (GLUE version).""" """Processor for the SST-2 data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["0", "1"]
...@@ -609,7 +561,7 @@ class SstProcessor(DataProcessor): ...@@ -609,7 +561,7 @@ class SstProcessor(DataProcessor):
return examples return examples
class StsBProcessor(DataProcessor): class StsBProcessor(DefaultGLUEDataProcessor):
"""Processor for the STS-B data set (GLUE version).""" """Processor for the STS-B data set (GLUE version)."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode): def __init__(self, process_text_fn=tokenization.convert_to_unicode):
...@@ -618,18 +570,6 @@ class StsBProcessor(DataProcessor): ...@@ -618,18 +570,6 @@ class StsBProcessor(DataProcessor):
self.label_type = float self.label_type = float
self._labels = None self._labels = None
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def _create_examples_tfds(self, set_type): def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
...@@ -786,21 +726,9 @@ class TfdsProcessor(DataProcessor): ...@@ -786,21 +726,9 @@ class TfdsProcessor(DataProcessor):
return examples return examples
class WnliProcessor(DataProcessor): class WnliProcessor(DefaultGLUEDataProcessor):
"""Processor for the WNLI data set (GLUE version).""" """Processor for the WNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["0", "1"]
...@@ -1282,27 +1210,7 @@ class AXgProcessor(DataProcessor): ...@@ -1282,27 +1210,7 @@ class AXgProcessor(DataProcessor):
return examples return examples
class SuperGLUEDataProcessor(DataProcessor): class BoolQProcessor(DefaultGLUEDataProcessor):
"""Processor for the SuperGLUE dataset."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("validation")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples_tfds("test")
def _create_examples_tfds(self, set_type):
"""Creates examples for the training/dev/test sets."""
raise NotImplementedError()
class BoolQProcessor(SuperGLUEDataProcessor):
"""Processor for the BoolQ dataset (SuperGLUE diagnostics dataset).""" """Processor for the BoolQ dataset (SuperGLUE diagnostics dataset)."""
def get_labels(self): def get_labels(self):
...@@ -1331,7 +1239,7 @@ class BoolQProcessor(SuperGLUEDataProcessor): ...@@ -1331,7 +1239,7 @@ class BoolQProcessor(SuperGLUEDataProcessor):
return examples return examples
class CBProcessor(SuperGLUEDataProcessor): class CBProcessor(DefaultGLUEDataProcessor):
"""Processor for the CB dataset (SuperGLUE diagnostics dataset).""" """Processor for the CB dataset (SuperGLUE diagnostics dataset)."""
def get_labels(self): def get_labels(self):
...@@ -1360,7 +1268,7 @@ class CBProcessor(SuperGLUEDataProcessor): ...@@ -1360,7 +1268,7 @@ class CBProcessor(SuperGLUEDataProcessor):
return examples return examples
class SuperGLUERTEProcessor(SuperGLUEDataProcessor): class SuperGLUERTEProcessor(DefaultGLUEDataProcessor):
"""Processor for the RTE dataset (SuperGLUE version).""" """Processor for the RTE dataset (SuperGLUE version)."""
def get_labels(self): def get_labels(self):
...@@ -1396,7 +1304,8 @@ def file_based_convert_examples_to_features(examples, ...@@ -1396,7 +1304,8 @@ def file_based_convert_examples_to_features(examples,
max_seq_length, max_seq_length,
tokenizer, tokenizer,
output_file, output_file,
label_type=None): label_type=None,
featurize_fn=None):
"""Convert a set of `InputExample`s to a TFRecord file.""" """Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file)) tf.io.gfile.makedirs(os.path.dirname(output_file))
...@@ -1406,8 +1315,12 @@ def file_based_convert_examples_to_features(examples, ...@@ -1406,8 +1315,12 @@ def file_based_convert_examples_to_features(examples,
if ex_index % 10000 == 0: if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples)) logging.info("Writing example %d of %d", ex_index, len(examples))
feature = convert_single_example(ex_index, example, label_list, if featurize_fn:
max_seq_length, tokenizer) feature = featurize_fn(ex_index, example, label_list, max_seq_length,
tokenizer)
else:
feature = convert_single_example(ex_index, example, label_list,
max_seq_length, tokenizer)
def create_int_feature(values): def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
...@@ -1496,7 +1409,8 @@ def generate_tf_record_from_data_file(processor, ...@@ -1496,7 +1409,8 @@ def generate_tf_record_from_data_file(processor,
file_based_convert_examples_to_features(train_input_data_examples, file_based_convert_examples_to_features(train_input_data_examples,
label_list, max_seq_length, label_list, max_seq_length,
tokenizer, train_data_output_path, tokenizer, train_data_output_path,
label_type) label_type,
processor.featurize_example)
num_training_data = len(train_input_data_examples) num_training_data = len(train_input_data_examples)
if eval_data_output_path: if eval_data_output_path:
...@@ -1504,7 +1418,8 @@ def generate_tf_record_from_data_file(processor, ...@@ -1504,7 +1418,8 @@ def generate_tf_record_from_data_file(processor,
file_based_convert_examples_to_features(eval_input_data_examples, file_based_convert_examples_to_features(eval_input_data_examples,
label_list, max_seq_length, label_list, max_seq_length,
tokenizer, eval_data_output_path, tokenizer, eval_data_output_path,
label_type) label_type,
processor.featurize_example)
meta_data = { meta_data = {
"processor_type": processor.get_processor_name(), "processor_type": processor.get_processor_name(),
...@@ -1518,13 +1433,15 @@ def generate_tf_record_from_data_file(processor, ...@@ -1518,13 +1433,15 @@ def generate_tf_record_from_data_file(processor,
for language, examples in test_input_data_examples.items(): for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features( file_based_convert_examples_to_features(
examples, label_list, max_seq_length, tokenizer, examples, label_list, max_seq_length, tokenizer,
test_data_output_path.format(language), label_type) test_data_output_path.format(language), label_type,
processor.featurize_example)
meta_data["test_{}_data_size".format(language)] = len(examples) meta_data["test_{}_data_size".format(language)] = len(examples)
else: else:
file_based_convert_examples_to_features(test_input_data_examples, file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length, label_list, max_seq_length,
tokenizer, test_data_output_path, tokenizer, test_data_output_path,
label_type) label_type,
processor.featurize_example)
meta_data["test_data_size"] = len(test_input_data_examples) meta_data["test_data_size"] = len(test_input_data_examples)
if is_regression: if is_regression:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment