Commit 78e54775 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Let the SentencePrediction task always report accuracy/AUC, and run extra metrics when they are needed.

Report MRPC with F1 by default.
Remove TODOs.

PiperOrigin-RevId: 454653695
parent bfcf684e
......@@ -55,7 +55,7 @@ EVAL_METRIC_MAP = {
'AX': 'matthews_corrcoef',
'COLA': 'matthews_corrcoef',
'MNLI': 'cls_accuracy',
'MRPC': 'cls_accuracy',
'MRPC': 'f1',
'QNLI': 'cls_accuracy',
'QQP': 'f1',
'RTE': 'cls_accuracy',
......@@ -93,12 +93,12 @@ def _override_exp_config_by_flags(exp_config, input_meta_data):
binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'],
metric_type='matthews_corrcoef')
elif FLAGS.task_name in ('MNLI', 'MRPC', 'QNLI', 'RTE', 'SST-2',
elif FLAGS.task_name in ('MNLI', 'QNLI', 'RTE', 'SST-2',
'WNLI'):
override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'])
elif FLAGS.task_name in ('QQP',):
elif FLAGS.task_name in ('QQP', 'MRPC'):
override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config,
metric_type='f1',
......
......@@ -165,14 +165,17 @@ class SentencePredictionTask(base_task.Task):
compiled_metrics.update_state(labels[self.label_field], model_outputs)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
if self.metric_type == 'accuracy':
return super(SentencePredictionTask,
self).validation_step(inputs, model, metrics)
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
if self.metric_type == 'matthews_corrcoef':
logs.update({
'sentence_prediction': # Ensure one prediction along batch dimension.
......@@ -207,7 +210,7 @@ class SentencePredictionTask(base_task.Task):
labels = np.concatenate(aggregated_logs['labels'], axis=0)
if self.metric_type == 'f1':
preds = np.argmax(preds, axis=1)
return {self.metric_type: 100 * sklearn_metrics.f1_score(labels, preds)}
return {self.metric_type: sklearn_metrics.f1_score(labels, preds)}
elif self.metric_type == 'matthews_corrcoef':
preds = np.reshape(preds, -1)
labels = np.reshape(labels, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment