Commit c430556b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add `include_example_field` into `SentencePredictionTextDataLoader` so that we...

Add `include_example_field` into `SentencePredictionTextDataLoader` so that we can use the data loader in the predict step of `SentencePrediction` task.

PiperOrigin-RevId: 369215827
parent ffaa4035
......@@ -123,6 +123,7 @@ class SentencePredictionTextDataConfig(cfg.DataConfig):
preprocessing_hub_module_url: str = ''
# Either tfrecord or sstsable or recordio.
file_type: str = 'tfrecord'
include_example_id: bool = False
class TextProcessor(tf.Module):
......@@ -189,6 +190,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
self._text_fields = params.text_fields
self._label_field = params.label_field
self._label_type = params.label_type
self._include_example_id = params.include_example_id
self._text_processor = TextProcessor(
seq_length=params.seq_length,
vocab_file=params.vocab_file,
......@@ -200,6 +202,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
"""Berts preprocess."""
segments = [record[x] for x in self._text_fields]
model_inputs = self._text_processor(segments)
if self._include_example_id:
model_inputs['example_id'] = record['example_id']
y = record[self._label_field]
return model_inputs, y
......@@ -211,6 +215,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
label_type = LABEL_TYPES_MAP[self._label_type]
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment