Commit a2ebafbb authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 356587847
parent ae871f41
@@ -14,12 +14,16 @@
# limitations under the License.
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
from typing import List, Mapping, Optional

import dataclasses
import tensorflow as tf
import tensorflow_hub as hub

from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp import modeling
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
@@ -89,3 +93,152 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)


@dataclasses.dataclass
class SentencePredictionTextDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task with raw text."""
  # Either set `input_path`...
  input_path: str = ''
  # Either `int` or `float`.
  label_type: str = 'int'
  # ...or `tfds_name` and `tfds_split` to specify input.
  tfds_name: str = ''
  tfds_split: str = ''
  # The names of the text feature fields. The text features will be
  # concatenated in order.
  text_fields: Optional[List[str]] = None
  label_field: str = 'label'
  global_batch_size: int = 32
  seq_length: int = 128
  is_training: bool = True
  # Either build preprocessing with Python code by specifying these values
  # for modeling.layers.BertTokenizer()/SentencepieceTokenizer()...
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
  # file if tokenization is SentencePiece.
  vocab_file: str = ''
  lower_case: bool = True
  # ...or load preprocessing from a SavedModel at this location.
  preprocessing_hub_module_url: str = ''
  # Either 'tfrecord', 'sstable' or 'recordio'.
  file_type: str = 'tfrecord'
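
For orientation, here is a minimal sketch of the two mutually exclusive ways this config is meant to be filled in; the local paths are placeholders, while the TFDS name and hub module URL are the values exercised in the tests below, not requirements of the loader.

# Sketch only: '/tmp/train.tf_record' and '/tmp/vocab.txt' are placeholder
# paths, not part of this change.
wordpiece_config = SentencePredictionTextDataConfig(
    input_path='/tmp/train.tf_record',       # raw-text tf.Examples on disk...
    text_fields=['sentence1', 'sentence2'],  # ...with these string features,
    vocab_file='/tmp/vocab.txt',             # tokenized in Python (WordPiece).
    lower_case=True,
    seq_length=128,
    global_batch_size=32)

hub_config = SentencePredictionTextDataConfig(
    tfds_name='glue/mrpc',                   # or read a TFDS split instead,
    tfds_split='train',
    text_fields=['sentence1', 'sentence2'],
    preprocessing_hub_module_url=(           # preprocessed by a hub SavedModel.
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'))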


class TextProcessor(tf.Module):
  """Text features processing for sentence prediction task."""

  def __init__(self,
               seq_length: int,
               vocab_file: Optional[str] = None,
               tokenization: Optional[str] = None,
               lower_case: Optional[bool] = True,
               preprocessing_hub_module_url: Optional[str] = None):
    if preprocessing_hub_module_url:
      self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
      self._tokenizer = self._preprocessing_hub_module.tokenize

      def set_shape(t):
        # Before TF2.4, the sequence length dimension loaded from the
        # preprocessing hub module is None, so we recover the shape here.
        # TODO(b/157636658): Remove once TF2.4 is released and being used.
        t.set_shape([None, seq_length])
        return t

      def pack_inputs_fn(inputs):
        result = self._preprocessing_hub_module.bert_pack_inputs(
            inputs, seq_length=seq_length)
        result = tf.nest.map_structure(set_shape, result)
        return result

      self._pack_inputs = pack_inputs_fn
      return

    if tokenization == 'WordPiece':
      self._tokenizer = modeling.layers.BertTokenizer(
          vocab_file=vocab_file, lower_case=lower_case)
    elif tokenization == 'SentencePiece':
      self._tokenizer = modeling.layers.SentencepieceTokenizer(
          model_file_path=vocab_file, lower_case=lower_case,
          strip_diacritics=True)  # Strip diacritics to follow the ALBERT model.
    else:
      raise ValueError('Unsupported tokenization: %s' % tokenization)
    self._pack_inputs = modeling.layers.BertPackInputs(
        seq_length=seq_length,
        special_tokens_dict=self._tokenizer.get_special_tokens_dict())

  def __call__(self, segments):
    segments = [self._tokenizer(s) for s in segments]
    # BertTokenizer returns a RaggedTensor with shape [batch, word, subword];
    # SentencepieceTokenizer returns a RaggedTensor with shape
    # [batch, sentencepiece]. Merge the word/subword dimensions so both cases
    # yield [batch, token] before packing.
    segments = [
        tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
        for x in segments
    ]
    return self._pack_inputs(segments)
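
As a rough usage sketch (assuming a WordPiece vocab file at a placeholder path and toy input sentences), the processor maps one batched string tensor per text field to packed BERT inputs:

# Sketch only; 'vocab.txt' is a placeholder path and the sentences are toy data.
processor = TextProcessor(
    seq_length=128, vocab_file='vocab.txt',
    tokenization='WordPiece', lower_case=True)
segments = [
    tf.constant(['hello world', 'foo bar']),        # first text field, [batch]
    tf.constant(['goodbye world', 'lorem ipsum']),  # second text field, [batch]
]
inputs = processor(segments)
# `inputs` holds 'input_word_ids', 'input_mask' and 'input_type_ids', each of
# shape [batch, seq_length].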


@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
class SentencePredictionTextDataLoader(data_loader.DataLoader):
  """Loads dataset with raw text for sentence prediction task."""

  def __init__(self, params):
    if bool(params.tfds_name) != bool(params.tfds_split):
      raise ValueError('`tfds_name` and `tfds_split` should be specified or '
                       'unspecified at the same time.')
    if bool(params.tfds_name) == bool(params.input_path):
      raise ValueError('Must specify either `tfds_name` and `tfds_split` '
                       'or `input_path`.')
    if not params.text_fields:
      raise ValueError('Unexpected empty text fields.')
    if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
      raise ValueError('Must specify exactly one of vocab_file (with matching '
                       'lower_case flag) or preprocessing_hub_module_url.')
    self._params = params
    self._text_fields = params.text_fields
    self._label_field = params.label_field
    self._label_type = params.label_type
    self._text_processor = TextProcessor(
        seq_length=params.seq_length,
        vocab_file=params.vocab_file,
        tokenization=params.tokenization,
        lower_case=params.lower_case,
        preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
    """Tokenizes raw text fields and packs them into BERT model inputs."""
    segments = [record[x] for x in self._text_fields]
    model_inputs = self._text_processor(segments)
    y = record[self._label_field]
    return model_inputs, y

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {}
    for text_field in self._text_fields:
      name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
    label_type = LABEL_TYPES_MAP[self._label_type]
    name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so cast all int64 values to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=self._decode if self._params.input_path else None,
        params=self._params,
        postprocess_fn=self._bert_preprocess)
    return reader.read(input_context)
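
Putting the pieces together, a loader built from such a config iterates like any tf.data pipeline; this sketch reuses the placeholder paths from above and mirrors what the new tests below assert about feature names and shapes.

# Sketch only; paths are placeholders.
params = SentencePredictionTextDataConfig(
    input_path='/tmp/train.tf_record',
    text_fields=['sentence1', 'sentence2'],
    vocab_file='/tmp/vocab.txt',
    seq_length=128,
    global_batch_size=32)
dataset = SentencePredictionTextDataLoader(params).load()
features, labels = next(iter(dataset))
# features['input_word_ids'], features['input_mask'] and
# features['input_type_ids'] each have shape [32, 128]; labels has shape [32].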
@@ -20,10 +20,11 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from sentencepiece import SentencePieceTrainer
from official.nlp.data import sentence_prediction_dataloader as loader


def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)
@@ -54,6 +55,70 @@ def _create_fake_dataset(output_path, seq_length, label_type):
  writer.close()


def _create_fake_raw_dataset(output_path, text_fields, label_type):
  """Creates a fake tf record file."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_str_feature(value):
    f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    return f

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  def create_float_feature(values):
    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return f

  for _ in range(100):
    features = {}
    for text_field in text_fields:
      features[text_field] = create_str_feature([b'hello world'])

    if label_type == 'int':
      features['label'] = create_int_feature([0])
    elif label_type == 'float':
      features['label'] = create_float_feature([0.5])
    else:
      raise ValueError('Unexpected label_type: %s' % label_type)
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


def _create_fake_sentencepiece_model(output_dir):
  vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
  model_prefix = os.path.join(output_dir, 'spm_model')
  input_text_file_path = os.path.join(output_dir, 'train_input.txt')
  with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
    f.write(' '.join(vocab + ['\n']))
  # Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
  full_vocab_size = len(vocab) + 7
  flags = dict(
      model_prefix=model_prefix,
      model_type='word',
      input=input_text_file_path,
      pad_id=0,
      unk_id=1,
      control_symbols='[CLS],[SEP],[MASK]',
      vocab_size=full_vocab_size,
      bos_id=full_vocab_size - 2,
      eos_id=full_vocab_size - 1)
  SentencePieceTrainer.Train(' '.join(
      ['--{}={}'.format(k, v) for k, v in flags.items()]))
  return model_prefix + '.model'


def _create_fake_vocab_file(vocab_file_path):
  tokens = ['[PAD]']
  for i in range(1, 100):
    tokens.append('[unused%d]' % i)
  tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
  with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
    outfile.write('\n'.join(tokens))


class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('int', tf.int32), ('float', tf.float32))
@@ -61,14 +126,13 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    batch_size = 10
    seq_length = 128
    _create_fake_preprocessed_dataset(input_path, seq_length, label_type)
    data_config = loader.SentencePredictionDataConfig(
        input_path=input_path,
        seq_length=seq_length,
        global_batch_size=batch_size,
        label_type=label_type)
    dataset = loader.SentencePredictionDataLoader(data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
                          features.keys())
@@ -79,5 +143,108 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
    self.assertEqual(labels.dtype, expected_label_type)


class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
                                           parameterized.TestCase):

  @parameterized.parameters(True, False)
  def test_python_wordpiece_preprocessing(self, use_tfds):
    batch_size = 10
    seq_length = 256  # Non-default value.
    lower_case = True
    tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    text_fields = ['sentence1', 'sentence2']
    if not use_tfds:
      _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')

    vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
    _create_fake_vocab_file(vocab_file_path)
    data_config = loader.SentencePredictionTextDataConfig(
        input_path='' if use_tfds else tf_record_path,
        tfds_name='glue/mrpc' if use_tfds else '',
        tfds_split='train' if use_tfds else '',
        text_fields=text_fields,
        global_batch_size=batch_size,
        seq_length=seq_length,
        is_training=True,
        lower_case=lower_case,
        vocab_file=vocab_file_path)
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size,))

  @parameterized.parameters(True, False)
  def test_python_sentencepiece_preprocessing(self, use_tfds):
    batch_size = 10
    seq_length = 256  # Non-default value.
    lower_case = True
    tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    text_fields = ['sentence1', 'sentence2']
    if not use_tfds:
      _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')

    sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
    data_config = loader.SentencePredictionTextDataConfig(
        input_path='' if use_tfds else tf_record_path,
        tfds_name='glue/mrpc' if use_tfds else '',
        tfds_split='train' if use_tfds else '',
        text_fields=text_fields,
        global_batch_size=batch_size,
        seq_length=seq_length,
        is_training=True,
        lower_case=lower_case,
        tokenization='SentencePiece',
        vocab_file=sp_model_file_path,
    )
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size,))

  @parameterized.parameters(True, False)
  def test_saved_model_preprocessing(self, use_tfds):
    batch_size = 10
    seq_length = 256  # Non-default value.
    tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    text_fields = ['sentence1', 'sentence2']
    if not use_tfds:
      _create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')

    vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
    _create_fake_vocab_file(vocab_file_path)
    data_config = loader.SentencePredictionTextDataConfig(
        input_path='' if use_tfds else tf_record_path,
        tfds_name='glue/mrpc' if use_tfds else '',
        tfds_split='train' if use_tfds else '',
        text_fields=text_fields,
        global_batch_size=batch_size,
        seq_length=seq_length,
        is_training=True,
        preprocessing_hub_module_url=(
            'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
        label_type='int' if use_tfds else 'float',
    )
    dataset = loader.SentencePredictionTextDataLoader(data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size,))


if __name__ == '__main__':
  tf.test.main()