Commit 78e54775 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Let the SentencePrediction task always compute accuracy/AUC, and run extra metrics when they are needed.

Report MRPC with F1 by default.
Remove TODOs.

PiperOrigin-RevId: 454653695
parent bfcf684e
...@@ -55,7 +55,7 @@ EVAL_METRIC_MAP = { ...@@ -55,7 +55,7 @@ EVAL_METRIC_MAP = {
'AX': 'matthews_corrcoef', 'AX': 'matthews_corrcoef',
'COLA': 'matthews_corrcoef', 'COLA': 'matthews_corrcoef',
'MNLI': 'cls_accuracy', 'MNLI': 'cls_accuracy',
'MRPC': 'cls_accuracy', 'MRPC': 'f1',
'QNLI': 'cls_accuracy', 'QNLI': 'cls_accuracy',
'QQP': 'f1', 'QQP': 'f1',
'RTE': 'cls_accuracy', 'RTE': 'cls_accuracy',
...@@ -93,12 +93,12 @@ def _override_exp_config_by_flags(exp_config, input_meta_data): ...@@ -93,12 +93,12 @@ def _override_exp_config_by_flags(exp_config, input_meta_data):
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'], num_classes=input_meta_data['num_labels'],
metric_type='matthews_corrcoef') metric_type='matthews_corrcoef')
elif FLAGS.task_name in ('MNLI', 'MRPC', 'QNLI', 'RTE', 'SST-2', elif FLAGS.task_name in ('MNLI', 'QNLI', 'RTE', 'SST-2',
'WNLI'): 'WNLI'):
override_task_cfg_fn = functools.partial( override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels']) num_classes=input_meta_data['num_labels'])
elif FLAGS.task_name in ('QQP',): elif FLAGS.task_name in ('QQP', 'MRPC'):
override_task_cfg_fn = functools.partial( override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
metric_type='f1', metric_type='f1',
......
...@@ -165,14 +165,17 @@ class SentencePredictionTask(base_task.Task): ...@@ -165,14 +165,17 @@ class SentencePredictionTask(base_task.Task):
compiled_metrics.update_state(labels[self.label_field], model_outputs) compiled_metrics.update_state(labels[self.label_field], model_outputs)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
if self.metric_type == 'accuracy':
return super(SentencePredictionTask,
self).validation_step(inputs, model, metrics)
features, labels = inputs, inputs features, labels = inputs, inputs
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
loss = self.build_losses( loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses) labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
if self.metric_type == 'matthews_corrcoef': if self.metric_type == 'matthews_corrcoef':
logs.update({ logs.update({
'sentence_prediction': # Ensure one prediction along batch dimension. 'sentence_prediction': # Ensure one prediction along batch dimension.
...@@ -207,7 +210,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -207,7 +210,7 @@ class SentencePredictionTask(base_task.Task):
labels = np.concatenate(aggregated_logs['labels'], axis=0) labels = np.concatenate(aggregated_logs['labels'], axis=0)
if self.metric_type == 'f1': if self.metric_type == 'f1':
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
return {self.metric_type: 100 * sklearn_metrics.f1_score(labels, preds)} return {self.metric_type: sklearn_metrics.f1_score(labels, preds)}
elif self.metric_type == 'matthews_corrcoef': elif self.metric_type == 'matthews_corrcoef':
preds = np.reshape(preds, -1) preds = np.reshape(preds, -1)
labels = np.reshape(labels, -1) labels = np.reshape(labels, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment