You need to sign in or sign up before continuing.
Commit 969a3f34 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Let SentencePrediction task get accuracy/auc always and run extra metrics when they are needed.

report mrpc with f1 by default.
Remove todos.

PiperOrigin-RevId: 454653695
parent 04727817
...@@ -55,7 +55,7 @@ EVAL_METRIC_MAP = { ...@@ -55,7 +55,7 @@ EVAL_METRIC_MAP = {
'AX': 'matthews_corrcoef', 'AX': 'matthews_corrcoef',
'COLA': 'matthews_corrcoef', 'COLA': 'matthews_corrcoef',
'MNLI': 'cls_accuracy', 'MNLI': 'cls_accuracy',
'MRPC': 'cls_accuracy', 'MRPC': 'f1',
'QNLI': 'cls_accuracy', 'QNLI': 'cls_accuracy',
'QQP': 'f1', 'QQP': 'f1',
'RTE': 'cls_accuracy', 'RTE': 'cls_accuracy',
...@@ -93,12 +93,12 @@ def _override_exp_config_by_flags(exp_config, input_meta_data): ...@@ -93,12 +93,12 @@ def _override_exp_config_by_flags(exp_config, input_meta_data):
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'], num_classes=input_meta_data['num_labels'],
metric_type='matthews_corrcoef') metric_type='matthews_corrcoef')
elif FLAGS.task_name in ('MNLI', 'MRPC', 'QNLI', 'RTE', 'SST-2', elif FLAGS.task_name in ('MNLI', 'QNLI', 'RTE', 'SST-2',
'WNLI'): 'WNLI'):
override_task_cfg_fn = functools.partial( override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels']) num_classes=input_meta_data['num_labels'])
elif FLAGS.task_name in ('QQP',): elif FLAGS.task_name in ('QQP', 'MRPC'):
override_task_cfg_fn = functools.partial( override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
metric_type='f1', metric_type='f1',
......
...@@ -165,14 +165,17 @@ class SentencePredictionTask(base_task.Task): ...@@ -165,14 +165,17 @@ class SentencePredictionTask(base_task.Task):
compiled_metrics.update_state(labels[self.label_field], model_outputs) compiled_metrics.update_state(labels[self.label_field], model_outputs)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
if self.metric_type == 'accuracy':
return super(SentencePredictionTask,
self).validation_step(inputs, model, metrics)
features, labels = inputs, inputs features, labels = inputs, inputs
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
loss = self.build_losses( loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses) labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
if self.metric_type == 'matthews_corrcoef': if self.metric_type == 'matthews_corrcoef':
logs.update({ logs.update({
'sentence_prediction': # Ensure one prediction along batch dimension. 'sentence_prediction': # Ensure one prediction along batch dimension.
...@@ -207,7 +210,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -207,7 +210,7 @@ class SentencePredictionTask(base_task.Task):
labels = np.concatenate(aggregated_logs['labels'], axis=0) labels = np.concatenate(aggregated_logs['labels'], axis=0)
if self.metric_type == 'f1': if self.metric_type == 'f1':
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
return {self.metric_type: 100 * sklearn_metrics.f1_score(labels, preds)} return {self.metric_type: sklearn_metrics.f1_score(labels, preds)}
elif self.metric_type == 'matthews_corrcoef': elif self.metric_type == 'matthews_corrcoef':
preds = np.reshape(preds, -1) preds = np.reshape(preds, -1)
labels = np.reshape(labels, -1) labels = np.reshape(labels, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment