Commit de5c2758 authored by A. Unique TensorFlower

Make the labels in sentence_prediction DataLoader/Task a dict.

PiperOrigin-RevId: 378751789
parent c633f2c8
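In short, the data loaders previously yielded `(features, labels)` tuples; after this change each dataset element is a single dict that carries the label tensor under the key `'label_ids'`, optionally duplicated under a configured alias. An illustrative sketch (placeholder values, not text from the commit):

```python
# Before: elements were (features, labels) tuples, e.g.
#   ({'input_word_ids': ..., 'input_mask': ..., 'input_type_ids': ...}, label_ids)
# After: one dict per element, labels folded in alongside the model inputs.
element = {
    'input_word_ids': ...,   # model inputs, unchanged
    'input_mask': ...,
    'input_type_ids': ...,
    'label_ids': ...,        # label tensor now lives in the same dict
    # plus e.g. 'next_sentence_labels' when
    # label_name=('label_ids', 'next_sentence_labels') is configured
}
```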
@@ -14,7 +14,7 @@
 """Loads dataset for the sentence prediction (classification) task."""
 import functools
-from typing import List, Mapping, Optional
+from typing import List, Mapping, Optional, Tuple
 import dataclasses
 import tensorflow as tf
@@ -40,7 +40,9 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   label_type: str = 'int'
   # Whether to include the example id number.
   include_example_id: bool = False
-  outputs_as_dict: bool = False
+  # Maps the key in TfExample to feature name.
+  # E.g 'label_ids' to 'next_sentence_labels'
+  label_name: Optional[Tuple[str, str]] = None
 
 
 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -51,6 +53,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     self._params = params
     self._seq_length = params.seq_length
     self._include_example_id = params.include_example_id
+    if params.label_name:
+      self._label_name_mapping = dict([params.label_name])
+    else:
+      self._label_name_mapping = dict()
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -86,12 +92,12 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     if self._include_example_id:
       x['example_id'] = record['example_id']
-    if self._params.outputs_as_dict:
-      x['next_sentence_labels'] = record['label_ids']
-      return x
+    x['label_ids'] = record['label_ids']
+    if 'label_ids' in self._label_name_mapping:
+      x[self._label_name_mapping['label_ids']] = record['label_ids']
 
-    y = record['label_ids']
-    return (x, y)
+    return x
 
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
@@ -209,8 +215,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
     model_inputs = self._text_processor(segments)
     if self._include_example_id:
       model_inputs['example_id'] = record['example_id']
-    y = record[self._label_field]
-    return model_inputs, y
+    model_inputs['label_ids'] = record[self._label_field]
+    return model_inputs
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -132,16 +132,17 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
         global_batch_size=batch_size,
         label_type=label_type)
     dataset = loader.SentencePredictionDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
-    self.assertEqual(labels.dtype, expected_label_type)
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features['label_ids'].dtype, expected_label_type)
 
-  def test_load_dataset_as_dict(self):
+  def test_load_dataset_with_label_mapping(self):
     input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
     batch_size = 10
     seq_length = 128
@@ -151,15 +152,18 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
         seq_length=seq_length,
         global_batch_size=batch_size,
         label_type='int',
-        outputs_as_dict=True)
+        label_name=('label_ids', 'next_sentence_labels'))
     dataset = loader.SentencePredictionDataLoader(data_config).load()
     features = next(iter(dataset))
     self.assertCountEqual([
-        'input_word_ids', 'input_mask', 'input_type_ids', 'next_sentence_labels'
+        'input_word_ids', 'input_mask', 'input_type_ids',
+        'next_sentence_labels', 'label_ids'
     ], features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features['label_ids'].dtype, tf.int32)
     self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
     self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
@@ -192,13 +196,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         lower_case=lower_case,
         vocab_file=vocab_file_path)
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
 
   @parameterized.parameters(True, False)
   def test_python_sentencepiece_preprocessing(self, use_tfds):
@@ -225,13 +230,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         vocab_file=sp_model_file_path,
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
 
   @parameterized.parameters(True, False)
   def test_saved_model_preprocessing(self, use_tfds):
@@ -258,13 +264,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         label_type='int' if use_tfds else 'float',
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
 
 if __name__ == '__main__':
@@ -95,11 +95,12 @@ class SentencePredictionTask(base_task.Task):
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)
 
   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+    label_ids = labels['label_ids']
     if self.task_config.model.num_classes == 1:
-      loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
+      loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
     else:
       loss = tf.keras.losses.sparse_categorical_crossentropy(
-          labels, tf.cast(model_outputs, tf.float32), from_logits=True)
+          label_ids, tf.cast(model_outputs, tf.float32), from_logits=True)
 
     if aux_losses:
       loss += tf.add_n(aux_losses)
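A hedged illustration (values assumed, not from the commit) of the contract this hunk establishes: `build_losses` now expects `labels` to be a dict carrying the label tensor under `'label_ids'`, matching what the data loaders emit:

```python
import tensorflow as tf

# Hypothetical tensors; real inputs come from the data loader and the model.
labels = {'label_ids': tf.constant([1, 0])}            # labels are now a dict
model_outputs = tf.constant([[0.2, 0.8], [0.6, 0.4]])  # per-example class logits

# The classification branch of build_losses reads the tensor out of the dict.
loss = tf.keras.losses.sparse_categorical_crossentropy(
    labels['label_ids'], tf.cast(model_outputs, tf.float32), from_logits=True)
```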
@@ -120,7 +121,8 @@ class SentencePredictionTask(base_task.Task):
         y = tf.zeros((1,), dtype=tf.float32)
       else:
         y = tf.zeros((1, 1), dtype=tf.int32)
-      return x, y
+      x['label_ids'] = y
+      return x
 
     dataset = tf.data.Dataset.range(1)
     dataset = dataset.repeat()
@@ -142,7 +144,7 @@ class SentencePredictionTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels, model_outputs)
+      metric.update_state(labels['label_ids'], model_outputs)
 
   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     compiled_metrics.update_state(labels, model_outputs)
@@ -151,7 +153,7 @@ class SentencePredictionTask(base_task.Task):
     if self.metric_type == 'accuracy':
       return super(SentencePredictionTask,
                    self).validation_step(inputs, model, metrics)
-    features, labels = inputs
+    features, labels = inputs, inputs
     outputs = self.inference_step(features, model)
     loss = self.build_losses(
         labels=labels, model_outputs=outputs, aux_losses=model.losses)
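A note on `features, labels = inputs, inputs` (illustrative, not from the commit): each element is now one dict containing both the model inputs and `'label_ids'`, so the same dict can serve as both features and labels:

```python
inputs = {'input_word_ids': ..., 'label_ids': ...}  # hypothetical element
features, labels = inputs, inputs  # both names alias the same dict
# inference_step consumes features; build_losses reads labels['label_ids'].
```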
@@ -161,12 +163,12 @@ class SentencePredictionTask(base_task.Task):
         'sentence_prediction':  # Ensure one prediction along batch dimension.
             tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
         'labels':
-            labels,
+            labels['label_ids'],
     })
     if self.metric_type == 'pearson_spearman_corr':
       logs.update({
           'sentence_prediction': outputs,
-          'labels': labels,
+          'labels': labels['label_ids'],
       })
     return logs
@@ -250,7 +252,7 @@ def predict(task: SentencePredictionTask,
   def predict_step(inputs):
     """Replicated prediction calculation."""
-    x, _ = inputs
+    x = inputs
     example_id = x.pop('example_id')
     outputs = task.inference_step(x, model)
     return dict(example_id=example_id, predictions=outputs)
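For context, a sketch (dummy values, names assumed) of why `predict_step` changes: each dataset element is now a single dict rather than an `(x, y)` tuple, so the step consumes it directly:

```python
# Hypothetical element mirroring what the loader now yields (a single dict).
element = {
    'input_word_ids': [[101, 2023, 102]],  # dummy ids for illustration
    'example_id': [0],
    'label_ids': [1],
}
x = element                        # previously: x, _ = element
example_id = x.pop('example_id')   # remaining keys are passed to the model
```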