Commit d6f7c76a authored by Rajagopal Ananthanarayanan, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 355088475
parent 830551b8
...@@ -61,8 +61,7 @@ class BertPretrainDataLoader(data_loader.DataLoader): ...@@ -61,8 +61,7 @@ class BertPretrainDataLoader(data_loader.DataLoader):
self._use_next_sentence_label = params.use_next_sentence_label self._use_next_sentence_label = params.use_next_sentence_label
self._use_position_id = params.use_position_id self._use_position_id = params.use_position_id
def _decode(self, record: tf.Tensor): def _name_to_features(self):
"""Decodes a serialized tf.Example."""
name_to_features = { name_to_features = {
'input_mask': 'input_mask':
tf.io.FixedLenFeature([self._seq_length], tf.int64), tf.io.FixedLenFeature([self._seq_length], tf.int64),
...@@ -89,7 +88,11 @@ class BertPretrainDataLoader(data_loader.DataLoader): ...@@ -89,7 +88,11 @@ class BertPretrainDataLoader(data_loader.DataLoader):
if self._use_position_id: if self._use_position_id:
name_to_features['position_ids'] = tf.io.FixedLenFeature( name_to_features['position_ids'] = tf.io.FixedLenFeature(
[self._seq_length], tf.int64) [self._seq_length], tf.int64)
return name_to_features
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = self._name_to_features()
example = tf.io.parse_single_example(record, name_to_features) example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32. # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment