"vscode:/vscode.git/clone" did not exist on "5319098e2a2357dca2b144dfc005df234cb7ca79"
Commit 0e74158f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 384018258
parent 5f23689e
......@@ -222,13 +222,12 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
"""Berts preprocess."""
segments = [record[x] for x in self._text_fields]
model_inputs = self._text_processor(segments)
if self._include_example_id:
model_inputs['example_id'] = record['example_id']
model_inputs[self._label_field] = record[self._label_field]
for key in record:
if key not in self._text_fields:
model_inputs[key] = record[key]
return model_inputs
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
def name_to_features_spec(self):
name_to_features = {}
for text_field in self._text_fields:
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
......@@ -237,8 +236,11 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
return name_to_features
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
example = tf.io.parse_single_example(record, self.name_to_features_spec())
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
......
......@@ -198,9 +198,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
expected_keys = [
'input_word_ids', 'input_type_ids', 'input_mask', label_field
]
if use_tfds:
expected_keys += ['idx']
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
......@@ -233,9 +236,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
expected_keys = [
'input_word_ids', 'input_type_ids', 'input_mask', label_field
]
if use_tfds:
expected_keys += ['idx']
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
......@@ -268,9 +274,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
expected_keys = [
'input_word_ids', 'input_type_ids', 'input_mask', label_field
]
if use_tfds:
expected_keys += ['idx']
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment