Commit e77956d6 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 378911698
parent bab477df
......@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
label_type: str = 'int'
# Whether to include the example id number.
include_example_id: bool = False
label_field: str = 'label_ids'
# Maps the key in TfExample to the feature name,
# e.g., 'label_ids' to 'next_sentence_labels'.
label_name: Optional[Tuple[str, str]] = None
......@@ -53,6 +54,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
self._params = params
self._seq_length = params.seq_length
self._include_example_id = params.include_example_id
self._label_field = params.label_field
if params.label_name:
self._label_name_mapping = dict([params.label_name])
else:
......@@ -65,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], label_type),
self._label_field: tf.io.FixedLenFeature([], label_type),
}
if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
......@@ -92,10 +94,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
if self._include_example_id:
x['example_id'] = record['example_id']
x['label_ids'] = record['label_ids']
x[self._label_field] = record[self._label_field]
if 'label_ids' in self._label_name_mapping:
x[self._label_name_mapping['label_ids']] = record['label_ids']
if self._label_field in self._label_name_mapping:
x[self._label_name_mapping[self._label_field]] = record[self._label_field]
return x
......@@ -215,7 +217,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
model_inputs = self._text_processor(segments)
if self._include_example_id:
model_inputs['example_id'] = record['example_id']
model_inputs['label_ids'] = record[self._label_field]
model_inputs[self._label_field] = record[self._label_field]
return model_inputs
def _decode(self, record: tf.Tensor):
......
......@@ -197,13 +197,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
vocab_file=vocab_file_path)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['label_ids'].shape, (batch_size,))
self.assertEqual(features[label_field].shape, (batch_size,))
@parameterized.parameters(True, False)
def test_python_sentencepiece_preprocessing(self, use_tfds):
......@@ -231,13 +232,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['label_ids'].shape, (batch_size,))
self.assertEqual(features[label_field].shape, (batch_size,))
@parameterized.parameters(True, False)
def test_saved_model_preprocessing(self, use_tfds):
......@@ -265,13 +267,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features = next(iter(dataset))
label_field = data_config.label_field
self.assertCountEqual(
['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
['input_word_ids', 'input_type_ids', 'input_mask', label_field],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['label_ids'].shape, (batch_size,))
self.assertEqual(features[label_field].shape, (batch_size,))
if __name__ == '__main__':
......
......@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
if params.metric_type not in METRIC_TYPES:
raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
self.metric_type = params.metric_type
if hasattr(params.train_data, 'label_field'):
self.label_field = params.train_data.label_field
else:
self.label_field = 'label_ids'
def build_model(self):
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
......@@ -95,7 +99,7 @@ class SentencePredictionTask(base_task.Task):
use_encoder_pooler=self.task_config.model.use_encoder_pooler)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
label_ids = labels['label_ids']
label_ids = labels[self.label_field]
if self.task_config.model.num_classes == 1:
loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
else:
......@@ -121,7 +125,7 @@ class SentencePredictionTask(base_task.Task):
y = tf.zeros((1,), dtype=tf.float32)
else:
y = tf.zeros((1, 1), dtype=tf.int32)
x['label_ids'] = y
x[self.label_field] = y
return x
dataset = tf.data.Dataset.range(1)
......@@ -144,10 +148,10 @@ class SentencePredictionTask(base_task.Task):
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
metric.update_state(labels['label_ids'], model_outputs)
metric.update_state(labels[self.label_field], model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels, model_outputs)
compiled_metrics.update_state(labels[self.label_field], model_outputs)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
if self.metric_type == 'accuracy':
......@@ -163,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
'sentence_prediction': # Ensure one prediction along batch dimension.
tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
'labels':
labels['label_ids'],
labels[self.label_field],
})
if self.metric_type == 'pearson_spearman_corr':
logs.update({
'sentence_prediction': outputs,
'labels': labels['label_ids'],
'labels': labels[self.label_field],
})
return logs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment