"src/vscode:/vscode.git/clone" did not exist on "ad222fb9399332e0d7aaadeec7c06f8b91126b78"
Unverified commit b8dffcd5, authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Handle Same Labels for ROC AUC (#1581)

* Handle corner case for ROC AUC

* Update

* Update doc
parent a304df5d
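For context: scikit-learn's roc_auc_score, which Meter delegates to below, cannot compute a score when y_true contains a single class. A minimal standalone sketch of the failure mode this commit guards against (illustration only, not dgllife code):

    # scikit-learn raises a ValueError when y_true contains one class only.
    from sklearn.metrics import roc_auc_score

    y_true = [1, 1, 1]        # only the positive class is present
    y_pred = [0.2, 0.7, 0.9]

    try:
        roc_auc_score(y_true, y_pred)
    except ValueError as err:
        # "Only one class present in y_true. ROC AUC score is not defined in that case."
        print(err)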
@@ -173,7 +173,9 @@ class Meter(object):
             task_w = mask[:, task]
             task_y_true = y_true[:, task][task_w != 0]
             task_y_pred = y_pred[:, task][task_w != 0]
-            scores.append(score_func(task_y_true, task_y_pred))
+            task_score = score_func(task_y_true, task_y_pred)
+            if task_score is not None:
+                scores.append(task_score)
         return self._reduce_scores(scores, reduction)

     def pearson_r2(self, reduction='none'):
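With the change above, a per-task score of None is dropped before reduction, so undefined tasks do not count toward the mean or sum. A minimal standalone sketch of that pattern (the reduce_task_scores helper below is hypothetical, not the dgllife API):

    # Hypothetical helper mirroring the filtering behavior added above.
    def reduce_task_scores(task_scores, reduction='none'):
        # Undefined (None) tasks are skipped entirely.
        scores = [s for s in task_scores if s is not None]
        if reduction == 'mean':
            return sum(scores) / len(scores)
        elif reduction == 'sum':
            return sum(scores)
        else:
            return scores

    print(reduce_task_scores([0.9, None, 0.7], 'mean'))  # 0.8, averaged over the 2 defined tasks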
@@ -236,6 +238,9 @@ class Meter(object):
     def roc_auc_score(self, reduction='none'):
         """Compute roc-auc score for binary classification.

+        ROC-AUC scores are not well-defined when the labels for a task contain
+        only a single class. In that case, the task is ignored and a warning message is printed.
+
         Parameters
         ----------
         reduction : 'none' or 'mean' or 'sum'
@@ -253,6 +258,11 @@ class Meter(object):
         assert (self.mean is None) and (self.std is None), \
             'Label normalization should not be performed for binary classification.'

         def score(y_true, y_pred):
-            return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
+            if len(y_true.unique()) == 1:
+                print('Warning: Only one class {} present in y_true for a task. '
+                      'ROC AUC score is not defined in that case.'.format(y_true[0]))
+                return None
+            else:
+                return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
         return self.multilabel_score(score, reduction)
......
@@ -17,8 +17,8 @@ def test_Meter():

     # pearson r2
     meter = Meter(label_mean, label_std)
-    meter.update(label, pred)
-    true_scores = [0.7499999999999999, 0.7499999999999999]
+    meter.update(pred, label)
+    true_scores = [0.7500000774286983, 0.7500000516191412]
     assert meter.pearson_r2() == true_scores
     assert meter.pearson_r2('mean') == np.mean(true_scores)
     assert meter.pearson_r2('sum') == np.sum(true_scores)
@@ -27,7 +27,7 @@ def test_Meter():
     assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)

     meter = Meter(label_mean, label_std)
-    meter.update(label, pred, mask)
+    meter.update(pred, label, mask)
     true_scores = [1.0, 1.0]
     assert meter.pearson_r2() == true_scores
     assert meter.pearson_r2('mean') == np.mean(true_scores)
@@ -38,7 +38,7 @@ def test_Meter():

     # mae
     meter = Meter()
-    meter.update(label, pred)
+    meter.update(pred, label)
     true_scores = [0.1666666716337204, 0.1666666716337204]
     assert meter.mae() == true_scores
     assert meter.mae('mean') == np.mean(true_scores)
@@ -48,7 +48,7 @@ def test_Meter():
     assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

     meter = Meter()
-    meter.update(label, pred, mask)
+    meter.update(pred, label, mask)
     true_scores = [0.25, 0.0]
     assert meter.mae() == true_scores
     assert meter.mae('mean') == np.mean(true_scores)
@@ -57,23 +57,23 @@ def test_Meter():
     assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
     assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

-    # rmse
+    # rmsef
     meter = Meter(label_mean, label_std)
-    meter.update(label, pred)
-    true_scores = [0.22125875529784111, 0.5937311018897714]
+    meter.update(pred, label)
+    true_scores = [0.41068359261794546, 0.4106836107598449]
     assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
     assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))

     meter = Meter(label_mean, label_std)
-    meter.update(label, pred, mask)
-    true_scores = [0.1337071188699867, 0.5019903799993205]
+    meter.update(pred, label, mask)
+    true_scores = [0.44433766459035057, 0.5019903799993205]
     assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
     assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))

     # roc auc score
     meter = Meter()
-    meter.update(label, pred)
-    true_scores = [1.0, 0.75]
+    meter.update(pred, label)
+    true_scores = [1.0, 1.0]
     assert meter.roc_auc_score() == true_scores
     assert meter.roc_auc_score('mean') == np.mean(true_scores)
     assert meter.roc_auc_score('sum') == np.sum(true_scores)
@@ -82,7 +82,7 @@ def test_Meter():
     assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

     meter = Meter()
-    meter.update(label, pred, mask)
+    meter.update(pred, label, mask)
     true_scores = [1.0, 1.0]
     assert meter.roc_auc_score() == true_scores
     assert meter.roc_auc_score('mean') == np.mean(true_scores)
@@ -91,5 +91,20 @@ def test_Meter():
     assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
     assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

+def test_cases_with_undefined_scores():
+    label = torch.tensor([[0., 1.],
+                          [0., 1.],
+                          [1., 1.]])
+    pred = torch.tensor([[0.5, 0.5],
+                         [0., 1.],
+                         [1., 0.]])
+    meter = Meter()
+    meter.update(pred, label)
+    true_scores = [1.0]
+    assert meter.roc_auc_score() == true_scores
+    assert meter.roc_auc_score('mean') == np.mean(true_scores)
+    assert meter.roc_auc_score('sum') == np.sum(true_scores)
+
 if __name__ == '__main__':
     test_Meter()
+    test_cases_with_undefined_scores()
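The new test above doubles as a usage example. A sketch of the user-visible behavior after this commit (assuming the Meter class exported from dgllife.utils in DGL-LifeSci):

    import torch
    from dgllife.utils import Meter

    # Task 0 has both classes in its labels; task 1 has the positive class only.
    label = torch.tensor([[0., 1.],
                          [0., 1.],
                          [1., 1.]])
    pred = torch.tensor([[0.5, 0.5],
                         [0., 1.],
                         [1., 0.]])

    meter = Meter()
    meter.update(pred, label)
    # Task 1 triggers the warning and is skipped; only task 0 is scored.
    print(meter.roc_auc_score())  # [1.0]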