Commit 0e74158f authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 384018258
parent 5f23689e
...@@ -222,13 +222,12 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader): ...@@ -222,13 +222,12 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
"""Berts preprocess.""" """Berts preprocess."""
segments = [record[x] for x in self._text_fields] segments = [record[x] for x in self._text_fields]
model_inputs = self._text_processor(segments) model_inputs = self._text_processor(segments)
if self._include_example_id: for key in record:
model_inputs['example_id'] = record['example_id'] if key not in self._text_fields:
model_inputs[self._label_field] = record[self._label_field] model_inputs[key] = record[key]
return model_inputs return model_inputs
def _decode(self, record: tf.Tensor): def name_to_features_spec(self):
"""Decodes a serialized tf.Example."""
name_to_features = {} name_to_features = {}
for text_field in self._text_fields: for text_field in self._text_fields:
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string) name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
...@@ -237,8 +236,11 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader): ...@@ -237,8 +236,11 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type) name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
if self._include_example_id: if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64) name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features) return name_to_features
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
example = tf.io.parse_single_example(record, self.name_to_features_spec())
# tf.Example only supports tf.int64, but the TPU only supports tf.int32. # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32. # So cast all int64 to int32.
for name in example: for name in example:
......
...@@ -198,9 +198,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase, ...@@ -198,9 +198,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset = loader.SentencePredictionTextDataLoader(data_config).load() dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset)) features = next(iter(dataset))
label_field = data_config.label_field label_field = data_config.label_field
self.assertCountEqual( expected_keys = [
['input_word_ids', 'input_type_ids', 'input_mask', label_field], 'input_word_ids', 'input_type_ids', 'input_mask', label_field
features.keys()) ]
if use_tfds:
expected_keys += ['idx']
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length)) self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length)) self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length)) self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
...@@ -233,9 +236,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase, ...@@ -233,9 +236,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset = loader.SentencePredictionTextDataLoader(data_config).load() dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset)) features = next(iter(dataset))
label_field = data_config.label_field label_field = data_config.label_field
self.assertCountEqual( expected_keys = [
['input_word_ids', 'input_type_ids', 'input_mask', label_field], 'input_word_ids', 'input_type_ids', 'input_mask', label_field
features.keys()) ]
if use_tfds:
expected_keys += ['idx']
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length)) self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length)) self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length)) self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
...@@ -268,9 +274,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase, ...@@ -268,9 +274,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset = loader.SentencePredictionTextDataLoader(data_config).load() dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset)) features = next(iter(dataset))
label_field = data_config.label_field label_field = data_config.label_field
self.assertCountEqual( expected_keys = [
['input_word_ids', 'input_type_ids', 'input_mask', label_field], 'input_word_ids', 'input_type_ids', 'input_mask', label_field
features.keys()) ]
if use_tfds:
expected_keys += ['idx']
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length)) self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length)) self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length)) self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment