Commit c430556b authored by A. Unique TensorFlower
Browse files

Add `include_example_id` into `SentencePredictionTextDataLoader` so that we can use the data loader in the predict step of the `SentencePrediction` task.

Add `include_example_id` into `SentencePredictionTextDataConfig` and `SentencePredictionTextDataLoader` so that we can use the data loader in the predict step of the `SentencePrediction` task.

PiperOrigin-RevId: 369215827
parent ffaa4035
...@@ -123,6 +123,7 @@ class SentencePredictionTextDataConfig(cfg.DataConfig): ...@@ -123,6 +123,7 @@ class SentencePredictionTextDataConfig(cfg.DataConfig):
preprocessing_hub_module_url: str = '' preprocessing_hub_module_url: str = ''
# Either tfrecord or sstable or recordio. # Either tfrecord or sstable or recordio.
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
include_example_id: bool = False
class TextProcessor(tf.Module): class TextProcessor(tf.Module):
...@@ -189,6 +190,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader): ...@@ -189,6 +190,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
self._text_fields = params.text_fields self._text_fields = params.text_fields
self._label_field = params.label_field self._label_field = params.label_field
self._label_type = params.label_type self._label_type = params.label_type
self._include_example_id = params.include_example_id
self._text_processor = TextProcessor( self._text_processor = TextProcessor(
seq_length=params.seq_length, seq_length=params.seq_length,
vocab_file=params.vocab_file, vocab_file=params.vocab_file,
...@@ -200,6 +202,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader): ...@@ -200,6 +202,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
"""Berts preprocess.""" """Berts preprocess."""
segments = [record[x] for x in self._text_fields] segments = [record[x] for x in self._text_fields]
model_inputs = self._text_processor(segments) model_inputs = self._text_processor(segments)
if self._include_example_id:
model_inputs['example_id'] = record['example_id']
y = record[self._label_field] y = record[self._label_field]
return model_inputs, y return model_inputs, y
...@@ -211,6 +215,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader): ...@@ -211,6 +215,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
label_type = LABEL_TYPES_MAP[self._label_type] label_type = LABEL_TYPES_MAP[self._label_type]
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type) name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features) example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32. # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment