Commit 13100073 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 382655843
parent 00024735
...@@ -60,8 +60,8 @@ class SentencePredictionDataLoader(data_loader.DataLoader): ...@@ -60,8 +60,8 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
else: else:
self._label_name_mapping = dict() self._label_name_mapping = dict()
def _decode(self, record: tf.Tensor): def name_to_features_spec(self):
"""Decodes a serialized tf.Example.""" """Defines features to decode. Subclass may override to append features."""
label_type = LABEL_TYPES_MAP[self._params.label_type] label_type = LABEL_TYPES_MAP[self._params.label_type]
name_to_features = { name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64), 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
...@@ -72,7 +72,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader): ...@@ -72,7 +72,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
if self._include_example_id: if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64) name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features) return name_to_features
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
example = tf.io.parse_single_example(record, self.name_to_features_spec())
# tf.Example only supports tf.int64, but the TPU only supports tf.int32. # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32. # So cast all int64 to int32.
...@@ -86,20 +90,23 @@ class SentencePredictionDataLoader(data_loader.DataLoader): ...@@ -86,20 +90,23 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
def _parse(self, record: Mapping[str, tf.Tensor]): def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model.""" """Parses raw tensors into a dict of tensors to be consumed by the model."""
x = { key_mapping = {
'input_word_ids': record['input_ids'], 'input_ids': 'input_word_ids',
'input_mask': record['input_mask'], 'input_mask': 'input_mask',
'input_type_ids': record['segment_ids'] 'segment_ids': 'input_type_ids'
} }
if self._include_example_id: ret = {}
x['example_id'] = record['example_id'] for record_key in record:
if record_key in key_mapping:
x[self._label_field] = record[self._label_field] ret[key_mapping[record_key]] = record[record_key]
else:
ret[record_key] = record[record_key]
if self._label_field in self._label_name_mapping: if self._label_field in self._label_name_mapping:
x[self._label_name_mapping[self._label_field]] = record[self._label_field] ret[self._label_name_mapping[self._label_field]] = record[
self._label_field]
return x return ret
def load(self, input_context: Optional[tf.distribute.InputContext] = None): def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset.""" """Returns a tf.dataset.Dataset."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment