"vscode:/vscode.git/clone" did not exist on "8433780efb9c78ab3136bb8de8ed104284664438"
Commit e10b29f4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 330621236
parent 636ca66f
...@@ -60,6 +60,10 @@ flags.DEFINE_bool( ...@@ -60,6 +60,10 @@ flags.DEFINE_bool(
"gzip_compress", False, "gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.") "Whether to use `GZIP` compress option to get compressed TFRecord files.")
flags.DEFINE_bool(
"use_v2_feature_names", False,
"Whether to use the feature names consistent with the models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20, flags.DEFINE_integer("max_predictions_per_seq", 20,
...@@ -147,9 +151,14 @@ def write_instance_to_example_files(instances, tokenizer, max_seq_length, ...@@ -147,9 +151,14 @@ def write_instance_to_example_files(instances, tokenizer, max_seq_length,
next_sentence_label = 1 if instance.is_random_next else 0 next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict() features = collections.OrderedDict()
if FLAGS.use_v2_feature_names:
features["input_word_ids"] = create_int_feature(input_ids)
features["input_type_ids"] = create_int_feature(segment_ids)
else:
features["input_ids"] = create_int_feature(input_ids) features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids) features["segment_ids"] = create_int_feature(segment_ids)
features["input_mask"] = create_int_feature(input_mask)
features["masked_lm_positions"] = create_int_feature(masked_lm_positions) features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
features["masked_lm_ids"] = create_int_feature(masked_lm_ids) features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
features["masked_lm_weights"] = create_float_feature(masked_lm_weights) features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
......
...@@ -35,6 +35,12 @@ class BertPretrainDataConfig(cfg.DataConfig): ...@@ -35,6 +35,12 @@ class BertPretrainDataConfig(cfg.DataConfig):
max_predictions_per_seq: int = 76 max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True use_next_sentence_label: bool = True
use_position_id: bool = False use_position_id: bool = False
# Historically, BERT implementations take `input_ids` and `segment_ids` as
# feature names. Inside the TF Model Garden implementation, the Keras model
# inputs are set as `input_word_ids` and `input_type_ids`. When
# v2_feature_names is True, the data loader assumes the tf.Examples use
# `input_word_ids` and `input_type_ids` as keys.
use_v2_feature_names: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig) @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
...@@ -56,12 +62,8 @@ class BertPretrainDataLoader(data_loader.DataLoader): ...@@ -56,12 +62,8 @@ class BertPretrainDataLoader(data_loader.DataLoader):
def _decode(self, record: tf.Tensor): def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example.""" """Decodes a serialized tf.Example."""
name_to_features = { name_to_features = {
'input_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': 'input_mask':
tf.io.FixedLenFeature([self._seq_length], tf.int64), tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'masked_lm_positions': 'masked_lm_positions':
tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64), tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
'masked_lm_ids': 'masked_lm_ids':
...@@ -69,6 +71,16 @@ class BertPretrainDataLoader(data_loader.DataLoader): ...@@ -69,6 +71,16 @@ class BertPretrainDataLoader(data_loader.DataLoader):
'masked_lm_weights': 'masked_lm_weights':
tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32), tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
} }
if self._params.use_v2_feature_names:
name_to_features.update({
'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
})
else:
name_to_features.update({
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
})
if self._use_next_sentence_label: if self._use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1], name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64) tf.int64)
...@@ -91,13 +103,17 @@ class BertPretrainDataLoader(data_loader.DataLoader): ...@@ -91,13 +103,17 @@ class BertPretrainDataLoader(data_loader.DataLoader):
def _parse(self, record: Mapping[str, tf.Tensor]): def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model.""" """Parses raw tensors into a dict of tensors to be consumed by the model."""
x = { x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'], 'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids'],
'masked_lm_positions': record['masked_lm_positions'], 'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'], 'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'], 'masked_lm_weights': record['masked_lm_weights'],
} }
if self._params.use_v2_feature_names:
x['input_word_ids'] = record['input_word_ids']
x['input_type_ids'] = record['input_type_ids']
else:
x['input_word_ids'] = record['input_ids']
x['input_type_ids'] = record['segment_ids']
if self._use_next_sentence_label: if self._use_next_sentence_label:
x['next_sentence_labels'] = record['next_sentence_labels'] x['next_sentence_labels'] = record['next_sentence_labels']
if self._use_position_id: if self._use_position_id:
......
...@@ -24,8 +24,12 @@ import tensorflow as tf ...@@ -24,8 +24,12 @@ import tensorflow as tf
from official.nlp.data import pretrain_dataloader from official.nlp.data import pretrain_dataloader
def _create_fake_dataset(output_path, seq_length, max_predictions_per_seq, def _create_fake_dataset(output_path,
use_position_id, use_next_sentence_label): seq_length,
max_predictions_per_seq,
use_position_id,
use_next_sentence_label,
use_v2_feature_names=False):
"""Creates a fake dataset.""" """Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path) writer = tf.io.TFRecordWriter(output_path)
...@@ -40,8 +44,12 @@ def _create_fake_dataset(output_path, seq_length, max_predictions_per_seq, ...@@ -40,8 +44,12 @@ def _create_fake_dataset(output_path, seq_length, max_predictions_per_seq,
for _ in range(100): for _ in range(100):
features = {} features = {}
input_ids = np.random.randint(100, size=(seq_length)) input_ids = np.random.randint(100, size=(seq_length))
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(np.ones_like(input_ids)) features["input_mask"] = create_int_feature(np.ones_like(input_ids))
if use_v2_feature_names:
features["input_word_ids"] = create_int_feature(input_ids)
features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
else:
features["input_ids"] = create_int_feature(input_ids)
features["segment_ids"] = create_int_feature(np.ones_like(input_ids)) features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
features["masked_lm_positions"] = create_int_feature( features["masked_lm_positions"] = create_int_feature(
...@@ -102,6 +110,36 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase): ...@@ -102,6 +110,36 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
use_next_sentence_label) use_next_sentence_label)
self.assertEqual("position_ids" in features, use_position_id) self.assertEqual("position_ids" in features, use_position_id)
def test_v2_feature_names(self):
  """Checks the loader emits v2 keys when data uses v2 feature names.

  Writes a fake TFRecord whose examples use `input_word_ids` /
  `input_type_ids` (instead of the historical `input_ids` / `segment_ids`),
  then loads one batch with `use_v2_feature_names=True` and verifies all
  expected feature keys are present in the parsed output.
  """
  seq_length = 128
  max_predictions_per_seq = 20
  input_path = os.path.join(self.get_temp_dir(), "train.tf_record")
  # Fake data must be written with v2 names so _decode finds those keys.
  _create_fake_dataset(
      input_path,
      seq_length,
      max_predictions_per_seq,
      use_next_sentence_label=True,
      use_position_id=False,
      use_v2_feature_names=True)
  config = pretrain_dataloader.BertPretrainDataConfig(
      input_path=input_path,
      max_predictions_per_seq=max_predictions_per_seq,
      seq_length=seq_length,
      global_batch_size=10,
      is_training=True,
      use_next_sentence_label=True,
      use_position_id=False,
      use_v2_feature_names=True)
  batch = next(iter(pretrain_dataloader.BertPretrainDataLoader(config).load()))
  expected_keys = (
      "input_word_ids",
      "input_mask",
      "input_type_ids",
      "masked_lm_positions",
      "masked_lm_ids",
      "masked_lm_weights",
  )
  for key in expected_keys:
    self.assertIn(key, batch)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment