Commit c2a8f717 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 454382820
parent 7c16e618
@@ -57,7 +57,7 @@ EVAL_METRIC_MAP = {
     'MNLI': 'cls_accuracy',
     'MRPC': 'cls_accuracy',
     'QNLI': 'cls_accuracy',
-    'QQP': 'cls_accuracy',
+    'QQP': 'f1',
     'RTE': 'cls_accuracy',
     'SST-2': 'cls_accuracy',
     'STS-B': 'pearson_spearman_corr',
@@ -93,11 +93,16 @@ def _override_exp_config_by_flags(exp_config, input_meta_data):
         binary_helper.override_sentence_prediction_task_config,
         num_classes=input_meta_data['num_labels'],
         metric_type='matthews_corrcoef')
-  elif FLAGS.task_name in ('MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SST-2',
+  elif FLAGS.task_name in ('MNLI', 'MRPC', 'QNLI', 'RTE', 'SST-2',
                            'WNLI'):
     override_task_cfg_fn = functools.partial(
         binary_helper.override_sentence_prediction_task_config,
         num_classes=input_meta_data['num_labels'])
+  elif FLAGS.task_name in ('QQP',):
+    override_task_cfg_fn = functools.partial(
+        binary_helper.override_sentence_prediction_task_config,
+        metric_type='f1',
+        num_classes=input_meta_data['num_labels'])
   elif FLAGS.task_name in ('STS-B',):
     override_task_cfg_fn = functools.partial(
         binary_helper.override_sentence_prediction_task_config,
......
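For context on the hunk above, here is a minimal, runnable sketch of what the new QQP branch does. The body of override_sentence_prediction_task_config below is a hypothetical stand-in (the real helper lives in binary_helper and mutates an experiment config); only the functools.partial wiring mirrors the commit.

import functools

def override_sentence_prediction_task_config(task_cfg, num_classes,
                                             metric_type='accuracy'):
  # Hypothetical stand-in for binary_helper.override_sentence_prediction_task_config.
  task_cfg['num_classes'] = num_classes
  task_cfg['metric_type'] = metric_type
  return task_cfg

# QQP now pre-binds metric_type='f1'; the other GLUE tasks keep the default.
override_task_cfg_fn = functools.partial(
    override_sentence_prediction_task_config,
    metric_type='f1',
    num_classes=2)

print(override_task_cfg_fn({}))  # {'num_classes': 2, 'metric_type': 'f1'}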
@@ -34,7 +34,7 @@ from official.nlp.modeling import models
 from official.nlp.tasks import utils

 METRIC_TYPES = frozenset(
-    ['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])
+    ['accuracy', 'f1', 'matthews_corrcoef', 'pearson_spearman_corr'])


 @dataclasses.dataclass
@@ -180,7 +180,7 @@ class SentencePredictionTask(base_task.Task):
           'labels':
               labels[self.label_field],
       })
-    if self.metric_type == 'pearson_spearman_corr':
+    else:
       logs.update({
           'sentence_prediction': outputs,
           'labels': labels[self.label_field],
@@ -202,18 +202,20 @@ class SentencePredictionTask(base_task.Task):
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     if self.metric_type == 'accuracy':
       return None
-    elif self.metric_type == 'matthews_corrcoef':
-      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
-      preds = np.reshape(preds, -1)
-      labels = np.concatenate(aggregated_logs['labels'], axis=0)
+
+    preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
+    labels = np.concatenate(aggregated_logs['labels'], axis=0)
+    if self.metric_type == 'f1':
+      preds = np.argmax(preds, axis=1)
+      return {self.metric_type: 100 * sklearn_metrics.f1_score(labels, preds)}
+    elif self.metric_type == 'matthews_corrcoef':
+      preds = np.reshape(preds, -1)
       labels = np.reshape(labels, -1)
       return {
           self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
       }
     elif self.metric_type == 'pearson_spearman_corr':
-      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
       preds = np.reshape(preds, -1)
-      labels = np.concatenate(aggregated_logs['labels'], axis=0)
       labels = np.reshape(labels, -1)
       pearson_corr = stats.pearsonr(preds, labels)[0]
       spearman_corr = stats.spearmanr(preds, labels)[0]
......
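The new 'f1' branch in reduce_aggregated_logs can be exercised standalone. A sketch under the assumption that QQP predictions arrive as [num_examples, 2] logits and labels as integer class ids; the batch data below is made up for illustration.

import numpy as np
from sklearn import metrics as sklearn_metrics

# Two aggregated eval batches of 2-class logits and their labels (made up).
pred_batches = [np.array([[0.2, 0.8], [0.9, 0.1]]), np.array([[0.4, 0.6]])]
label_batches = [np.array([1, 0]), np.array([1])]

preds = np.concatenate(pred_batches, axis=0)    # shape [3, 2]
labels = np.concatenate(label_batches, axis=0)  # shape [3]
preds = np.argmax(preds, axis=1)                # logits -> class ids
print(100 * sklearn_metrics.f1_score(labels, preds))  # 100.0 on this toy data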
@@ -162,7 +162,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     self.assertLess(loss, 1.0)

   @parameterized.parameters(("matthews_corrcoef", 2),
-                            ("pearson_spearman_corr", 1))
+                            ("pearson_spearman_corr", 1),
+                            ("f1", 2))
   def test_np_metrics(self, metric_type, num_classes):
     config = sentence_prediction.SentencePredictionConfig(
         metric_type=metric_type,
......
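A plausible motivation for the switch (the commit message only says "Internal change"): the GLUE benchmark reports F1, alongside accuracy, for QQP, and F1 behaves differently from the correlation-style metrics under class imbalance. The toy comparison below is illustrative only and not part of the commit.

import numpy as np
from sklearn import metrics as sklearn_metrics

# A constant positive predictor on an imbalanced label set (made-up data).
labels = np.array([1, 1, 1, 1, 1, 1, 1, 0])
preds = np.ones_like(labels)

print(sklearn_metrics.f1_score(labels, preds))           # ~0.933
print(sklearn_metrics.matthews_corrcoef(labels, preds))  # 0.0 (degenerate predictor)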