".github/git@developer.sourcefind.cn:change/sglang.git" did not exist on "63a395b98517ee4a65476f8650919af43cc4c993"
Commit c90f8b16 authored by Rajagopal Ananthanarayanan's avatar Rajagopal Ananthanarayanan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 355088475
parent 790de575
......@@ -61,8 +61,7 @@ class BertPretrainDataLoader(data_loader.DataLoader):
self._use_next_sentence_label = params.use_next_sentence_label
self._use_position_id = params.use_position_id
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
def _name_to_features(self):
name_to_features = {
'input_mask':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
......@@ -89,7 +88,11 @@ class BertPretrainDataLoader(data_loader.DataLoader):
if self._use_position_id:
name_to_features['position_ids'] = tf.io.FixedLenFeature(
[self._seq_length], tf.int64)
return name_to_features
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = self._name_to_features()
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment