Commit dcdd2e40 authored by A. Unique TensorFlower

Make the labels in sentence_prediction DataLoader/Task a dict.

PiperOrigin-RevId: 378751789
parent 927e31aa
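
In effect, datasets produced by these loaders now yield a single features dict instead of a (features, labels) tuple. A minimal consumption sketch (the `dataset` variable is assumed to come from one of the loaders changed below):

    # Before this commit: features, labels = next(iter(dataset))
    # After: the labels ride along inside the features dict.
    features = next(iter(dataset))
    labels = features['label_ids']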
@@ -14,7 +14,7 @@
 """Loads dataset for the sentence prediction (classification) task."""
 import functools
-from typing import List, Mapping, Optional
+from typing import List, Mapping, Optional, Tuple

 import dataclasses
 import tensorflow as tf
@@ -40,7 +40,9 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   label_type: str = 'int'
   # Whether to include the example id number.
   include_example_id: bool = False
-  outputs_as_dict: bool = False
+  # Maps the key in the tf.Example to the output feature name,
+  # e.g. 'label_ids' to 'next_sentence_labels'.
+  label_name: Optional[Tuple[str, str]] = None


 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -51,6 +53,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     self._params = params
     self._seq_length = params.seq_length
     self._include_example_id = params.include_example_id
+    if params.label_name:
+      self._label_name_mapping = dict([params.label_name])
+    else:
+      self._label_name_mapping = dict()

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -86,12 +92,12 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     if self._include_example_id:
       x['example_id'] = record['example_id']

-    if self._params.outputs_as_dict:
-      x['next_sentence_labels'] = record['label_ids']
-      return x
+    x['label_ids'] = record['label_ids']
+    if 'label_ids' in self._label_name_mapping:
+      x[self._label_name_mapping['label_ids']] = record['label_ids']

-    y = record['label_ids']
-    return (x, y)
+    return x

   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
@@ -209,8 +215,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
     model_inputs = self._text_processor(segments)
     if self._include_example_id:
       model_inputs['example_id'] = record['example_id']
-    y = record[self._label_field]
-    return model_inputs, y
+    model_inputs['label_ids'] = record[self._label_field]
+    return model_inputs

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
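The text loader adopts the same single-dict convention, so its callers stop unpacking a tuple as well; a brief sketch:

    model_inputs = next(iter(dataset))
    y = model_inputs['label_ids']  # was: model_inputs, y = next(iter(dataset))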
@@ -132,16 +132,17 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
         global_batch_size=batch_size,
         label_type=label_type)
     dataset = loader.SentencePredictionDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
-    self.assertEqual(labels.dtype, expected_label_type)
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features['label_ids'].dtype, expected_label_type)

-  def test_load_dataset_as_dict(self):
+  def test_load_dataset_with_label_mapping(self):
     input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
     batch_size = 10
     seq_length = 128
@@ -151,15 +152,18 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
         seq_length=seq_length,
         global_batch_size=batch_size,
         label_type='int',
-        outputs_as_dict=True)
+        label_name=('label_ids', 'next_sentence_labels'))
     dataset = loader.SentencePredictionDataLoader(data_config).load()
     features = next(iter(dataset))
     self.assertCountEqual([
-        'input_word_ids', 'input_mask', 'input_type_ids', 'next_sentence_labels'
+        'input_word_ids', 'input_mask', 'input_type_ids',
+        'next_sentence_labels', 'label_ids'
     ], features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features['label_ids'].dtype, tf.int32)
     self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
     self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
@@ -192,13 +196,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         lower_case=lower_case,
         vocab_file=vocab_file_path)
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))

   @parameterized.parameters(True, False)
   def test_python_sentencepiece_preprocessing(self, use_tfds):
@@ -225,13 +230,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         vocab_file=sp_model_file_path,
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))

   @parameterized.parameters(True, False)
   def test_saved_model_preprocessing(self, use_tfds):
@@ -258,13 +264,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         label_type='int' if use_tfds else 'float',
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))


 if __name__ == '__main__':
@@ -95,11 +95,12 @@ class SentencePredictionTask(base_task.Task):
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+    label_ids = labels['label_ids']
     if self.task_config.model.num_classes == 1:
-      loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
+      loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
     else:
       loss = tf.keras.losses.sparse_categorical_crossentropy(
-          labels, tf.cast(model_outputs, tf.float32), from_logits=True)
+          label_ids, tf.cast(model_outputs, tf.float32), from_logits=True)

     if aux_losses:
       loss += tf.add_n(aux_losses)
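A toy invocation of the updated loss under the new labels contract (the `task` instance is assumed to be a configured SentencePredictionTask with two classes):

    labels = {'label_ids': tf.constant([1, 0])}
    logits = tf.constant([[2.0, -1.0], [0.5, 0.5]])  # (batch, num_classes)
    loss = task.build_losses(labels=labels, model_outputs=logits)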
@@ -120,7 +121,8 @@ class SentencePredictionTask(base_task.Task):
         y = tf.zeros((1,), dtype=tf.float32)
       else:
         y = tf.zeros((1, 1), dtype=tf.int32)
-      return x, y
+      x['label_ids'] = y
+      return x

     dataset = tf.data.Dataset.range(1)
     dataset = dataset.repeat()
@@ -142,7 +144,7 @@ class SentencePredictionTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels, model_outputs)
+      metric.update_state(labels['label_ids'], model_outputs)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     compiled_metrics.update_state(labels, model_outputs)
@@ -151,7 +153,7 @@ class SentencePredictionTask(base_task.Task):
     if self.metric_type == 'accuracy':
       return super(SentencePredictionTask,
                    self).validation_step(inputs, model, metrics)
-    features, labels = inputs
+    features, labels = inputs, inputs
     outputs = self.inference_step(features, model)
     loss = self.build_losses(
         labels=labels, model_outputs=outputs, aux_losses=model.losses)
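The `features, labels = inputs, inputs` aliasing works because the dataset element is now one dict serving both roles: the model reads the `input_*` keys while `build_losses` and the metrics read the `'label_ids'` key from the same dict.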
@@ -161,12 +163,12 @@ class SentencePredictionTask(base_task.Task):
           'sentence_prediction':  # Ensure one prediction along batch dimension.
               tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
           'labels':
-              labels,
+              labels['label_ids'],
       })
     if self.metric_type == 'pearson_spearman_corr':
       logs.update({
           'sentence_prediction': outputs,
-          'labels': labels,
+          'labels': labels['label_ids'],
       })
     return logs
@@ -250,7 +252,7 @@ def predict(task: SentencePredictionTask,

   def predict_step(inputs):
     """Replicated prediction calculation."""
-    x, _ = inputs
+    x = inputs
     example_id = x.pop('example_id')
     outputs = task.inference_step(x, model)
     return dict(example_id=example_id, predictions=outputs)
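Downstream, the predict path consumes the same single-dict batches; a minimal sketch (distribution-strategy wrapping omitted):

    for batch in dataset:
        logs = predict_step(batch)  # batch is one features dict, no unpacking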