Commit e689246f authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 423199224
parent c41d6565
...@@ -79,17 +79,29 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader): ...@@ -79,17 +79,29 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
def _decode(self, record: tf.Tensor): def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example.""" """Decodes a serialized tf.Example."""
name_to_features = { name_to_features = {
'input_ids': tf.io.VarLenFeature(tf.int64),
'input_mask': tf.io.VarLenFeature(tf.int64), 'input_mask': tf.io.VarLenFeature(tf.int64),
'segment_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_positions': tf.io.VarLenFeature(tf.int64), 'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
'masked_lm_ids': tf.io.VarLenFeature(tf.int64), 'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_weights': tf.io.VarLenFeature(tf.float32), 'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
} }
if self._params.use_v2_feature_names:
input_ids_key = 'input_word_ids'
segment_key = 'input_type_ids'
name_to_features.update({
input_ids_key: tf.io.VarLenFeature(tf.int64),
segment_key: tf.io.VarLenFeature(tf.int64),
})
else:
input_ids_key = 'input_ids'
segment_key = 'segment_ids'
name_to_features.update({
input_ids_key: tf.io.VarLenFeature(tf.int64),
segment_key: tf.io.VarLenFeature(tf.int64),
})
if self._use_next_sentence_label: if self._use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1], name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64) tf.int64)
dynamic_keys = ['input_ids', 'input_mask', 'segment_ids'] dynamic_keys = [input_ids_key, 'input_mask', segment_key]
if self._use_position_id: if self._use_position_id:
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64) name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
dynamic_keys.append('position_ids') dynamic_keys.append('position_ids')
...@@ -102,7 +114,7 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader): ...@@ -102,7 +114,7 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
# sequence length dimension. # sequence length dimension.
# Pad before the first non pad from the back should not be removed. # Pad before the first non pad from the back should not be removed.
mask = tf.math.greater( mask = tf.math.greater(
tf.math.cumsum(example['input_ids'], reverse=True), 0) tf.math.cumsum(example[input_ids_key], reverse=True), 0)
for key in dynamic_keys: for key in dynamic_keys:
example[key] = tf.boolean_mask(example[key], mask) example[key] = tf.boolean_mask(example[key], mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment