Commit 06f74216 authored by Hongkun Yu, committed by A. Unique TensorFlower

Add outputs_as_dict to SentencePredictionDataLoader.

PiperOrigin-RevId: 378190887
parent aa94accd
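
Note: with outputs_as_dict=True the parser folds the label into the feature dictionary under 'next_sentence_labels' and returns a single dict per example, instead of the usual (features, labels) tuple. Below is a minimal standalone sketch of that branching, assuming a decoded record that already carries the feature keys exercised by the new test; the helper name and placeholder shapes are illustrative only, not the loader's actual implementation.

import tensorflow as tf

def split_record(record, outputs_as_dict=False):
  """Illustrative sketch of the tuple-vs-dict branching this commit adds."""
  x = {k: record[k]
       for k in ('input_word_ids', 'input_mask', 'input_type_ids')}
  if outputs_as_dict:
    # New path: the label joins the feature dict as 'next_sentence_labels'.
    x['next_sentence_labels'] = record['label_ids']
    return x
  # Existing path: features and label are returned as an (x, y) tuple.
  return x, record['label_ids']

# Hypothetical decoded record with placeholder shapes.
record = {
    'input_word_ids': tf.zeros((128,), tf.int32),
    'input_mask': tf.ones((128,), tf.int32),
    'input_type_ids': tf.zeros((128,), tf.int32),
    'label_ids': tf.constant(1, tf.int32),
}
features = split_record(record, outputs_as_dict=True)
print(sorted(features.keys()))
# ['input_mask', 'input_type_ids', 'input_word_ids', 'next_sentence_labels']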
@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   label_type: str = 'int'
   # Whether to include the example id number.
   include_example_id: bool = False
+  outputs_as_dict: bool = False


 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -85,6 +86,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     if self._include_example_id:
       x['example_id'] = record['example_id']

+    if self._params.outputs_as_dict:
+      x['next_sentence_labels'] = record['label_ids']
+      return x
+
     y = record['label_ids']
     return (x, y)
@@ -141,6 +141,28 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(labels.shape, (batch_size,))
     self.assertEqual(labels.dtype, expected_label_type)

+  def test_load_dataset_as_dict(self):
+    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+    batch_size = 10
+    seq_length = 128
+    _create_fake_preprocessed_dataset(input_path, seq_length, 'int')
+    data_config = loader.SentencePredictionDataConfig(
+        input_path=input_path,
+        seq_length=seq_length,
+        global_batch_size=batch_size,
+        label_type='int',
+        outputs_as_dict=True)
+    dataset = loader.SentencePredictionDataLoader(data_config).load()
+    features = next(iter(dataset))
+    self.assertCountEqual([
+        'input_word_ids', 'input_mask', 'input_type_ids', 'next_sentence_labels'
+    ], features.keys())
+    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+    self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
+    self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
+


 class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
                                            parameterized.TestCase):
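
For context, a hedged usage sketch of the new flag outside the test harness, assuming the module lives at official/nlp/data/sentence_prediction_dataloader.py as in this repo's layout; the input path and sizes are placeholders, not values from the commit.

from official.nlp.data import sentence_prediction_dataloader as loader

data_config = loader.SentencePredictionDataConfig(
    input_path='/tmp/train.tf_record',  # placeholder path to preprocessed data
    seq_length=128,
    global_batch_size=32,
    label_type='int',
    outputs_as_dict=True)
dataset = loader.SentencePredictionDataLoader(data_config).load()
for features in dataset.take(1):
  # Each element is now a single dict; the label lives under
  # 'next_sentence_labels' rather than in a separate labels tensor.
  print(features['next_sentence_labels'].shape)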