Commit e77956d6 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 378911698
parent bab477df
@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   label_type: str = 'int'
   # Whether to include the example id number.
   include_example_id: bool = False
+  label_field: str = 'label_ids'
   # Maps the key in TfExample to feature name.
   # E.g 'label_ids' to 'next_sentence_labels'
   label_name: Optional[Tuple[str, str]] = None

@@ -53,6 +54,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     self._params = params
     self._seq_length = params.seq_length
     self._include_example_id = params.include_example_id
+    self._label_field = params.label_field
     if params.label_name:
       self._label_name_mapping = dict([params.label_name])
     else:

@@ -65,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], label_type),
+        self._label_field: tf.io.FixedLenFeature([], label_type),
     }
     if self._include_example_id:
       name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)

@@ -92,10 +94,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     if self._include_example_id:
       x['example_id'] = record['example_id']
-    x['label_ids'] = record['label_ids']
-    if 'label_ids' in self._label_name_mapping:
-      x[self._label_name_mapping['label_ids']] = record['label_ids']
+    x[self._label_field] = record[self._label_field]
+    if self._label_field in self._label_name_mapping:
+      x[self._label_name_mapping[self._label_field]] = record[self._label_field]
     return x

@@ -215,7 +217,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
     model_inputs = self._text_processor(segments)
     if self._include_example_id:
       model_inputs['example_id'] = record['example_id']
-    model_inputs['label_ids'] = record[self._label_field]
+    model_inputs[self._label_field] = record[self._label_field]
     return model_inputs

   def _decode(self, record: tf.Tensor):
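For orientation, here is a minimal usage sketch of the new label_field option (not part of the commit). The import path and the config values below (input path, sequence length, batch size, label name) are illustrative assumptions rather than code from this change:

    # Hedged sketch: the module path and all config values are assumptions.
    from official.nlp.data import sentence_prediction_dataloader as loader

    data_config = loader.SentencePredictionDataConfig(
        input_path='/tmp/train.tf_record',  # hypothetical input file
        seq_length=128,
        global_batch_size=32,
        label_field='sentiment')  # key used both in the TfExample and in the parsed batch
    dataset = loader.SentencePredictionDataLoader(data_config).load()
    features = next(iter(dataset))
    # Labels now appear under features['sentiment'] instead of features['label_ids'].

If label_name=('sentiment', 'next_sentence_labels') were also set, the loader would additionally expose the same values under the mapped name, per the _label_name_mapping branch in the hunk above.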
@@ -197,13 +197,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         vocab_file=vocab_file_path)
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
     features = next(iter(dataset))
+    label_field = data_config.label_field
     self.assertCountEqual(
-        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
         features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features[label_field].shape, (batch_size,))

   @parameterized.parameters(True, False)
   def test_python_sentencepiece_preprocessing(self, use_tfds):

@@ -231,13 +232,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
     features = next(iter(dataset))
+    label_field = data_config.label_field
     self.assertCountEqual(
-        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
         features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features[label_field].shape, (batch_size,))

   @parameterized.parameters(True, False)
   def test_saved_model_preprocessing(self, use_tfds):

@@ -265,13 +267,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
     features = next(iter(dataset))
+    label_field = data_config.label_field
     self.assertCountEqual(
-        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
         features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features[label_field].shape, (batch_size,))

 if __name__ == '__main__':
@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
     if params.metric_type not in METRIC_TYPES:
       raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
     self.metric_type = params.metric_type
+    if hasattr(params.train_data, 'label_field'):
+      self.label_field = params.train_data.label_field
+    else:
+      self.label_field = 'label_ids'

   def build_model(self):
     if self.task_config.hub_module_url and self.task_config.init_checkpoint:

@@ -95,7 +99,7 @@ class SentencePredictionTask(base_task.Task):
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
-    label_ids = labels['label_ids']
+    label_ids = labels[self.label_field]
     if self.task_config.model.num_classes == 1:
       loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
     else:

@@ -121,7 +125,7 @@ class SentencePredictionTask(base_task.Task):
         y = tf.zeros((1,), dtype=tf.float32)
       else:
         y = tf.zeros((1, 1), dtype=tf.int32)
-      x['label_ids'] = y
+      x[self.label_field] = y
       return x

     dataset = tf.data.Dataset.range(1)

@@ -144,10 +148,10 @@ class SentencePredictionTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels['label_ids'], model_outputs)
+      metric.update_state(labels[self.label_field], model_outputs)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
-    compiled_metrics.update_state(labels, model_outputs)
+    compiled_metrics.update_state(labels[self.label_field], model_outputs)

   def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
     if self.metric_type == 'accuracy':

@@ -163,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
           'sentence_prediction':  # Ensure one prediction along batch dimension.
               tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
           'labels':
-              labels['label_ids'],
+              labels[self.label_field],
       })
     if self.metric_type == 'pearson_spearman_corr':
       logs.update({
           'sentence_prediction': outputs,
-          'labels': labels['label_ids'],
+          'labels': labels[self.label_field],
       })
     return logs
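One design note on the task change above: reading label_field through hasattr keeps older experiment configs working, since a train_data config created before this field existed still falls back to the previous hard-coded 'label_ids' key. An equivalent, more compact way to express that fallback (an illustrative rewrite, not the committed code):

    def resolve_label_field(train_data_config) -> str:
      # Mirrors the hasattr/else branch in SentencePredictionTask.__init__:
      # configs that predate `label_field` resolve to 'label_ids'.
      return getattr(train_data_config, 'label_field', 'label_ids')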