Unverified Commit b8dffcd5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[DGL-LifeSci] Handle Same Labels for ROC AUC (#1581)

* Handle corner case for ROC AUC

* Update

* Update docstring
parent a304df5d
...@@ -173,7 +173,9 @@ class Meter(object): ...@@ -173,7 +173,9 @@ class Meter(object):
task_w = mask[:, task] task_w = mask[:, task]
task_y_true = y_true[:, task][task_w != 0] task_y_true = y_true[:, task][task_w != 0]
task_y_pred = y_pred[:, task][task_w != 0] task_y_pred = y_pred[:, task][task_w != 0]
scores.append(score_func(task_y_true, task_y_pred)) task_score = score_func(task_y_true, task_y_pred)
if task_score is not None:
scores.append(task_score)
return self._reduce_scores(scores, reduction) return self._reduce_scores(scores, reduction)
def pearson_r2(self, reduction='none'): def pearson_r2(self, reduction='none'):
...@@ -236,6 +238,9 @@ class Meter(object): ...@@ -236,6 +238,9 @@ class Meter(object):
def roc_auc_score(self, reduction='none'): def roc_auc_score(self, reduction='none'):
"""Compute roc-auc score for binary classification. """Compute roc-auc score for binary classification.
ROC-AUC scores are not well-defined in cases where labels for a task have one single
class only. In this case we will simply ignore this task and print a warning message.
Parameters Parameters
---------- ----------
reduction : 'none' or 'mean' or 'sum' reduction : 'none' or 'mean' or 'sum'
...@@ -253,7 +258,12 @@ class Meter(object): ...@@ -253,7 +258,12 @@ class Meter(object):
assert (self.mean is None) and (self.std is None), \ assert (self.mean is None) and (self.std is None), \
'Label normalization should not be performed for binary classification.' 'Label normalization should not be performed for binary classification.'
def score(y_true, y_pred): def score(y_true, y_pred):
return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy()) if len(y_true.unique()) == 1:
print('Warning: Only one class {} present in y_true for a task. '
'ROC AUC score is not defined in that case.'.format(y_true[0]))
return None
else:
return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
return self.multilabel_score(score, reduction) return self.multilabel_score(score, reduction)
def compute_metric(self, metric_name, reduction='none'): def compute_metric(self, metric_name, reduction='none'):
......
...@@ -17,8 +17,8 @@ def test_Meter(): ...@@ -17,8 +17,8 @@ def test_Meter():
# pearson r2 # pearson r2
meter = Meter(label_mean, label_std) meter = Meter(label_mean, label_std)
meter.update(label, pred) meter.update(pred, label)
true_scores = [0.7499999999999999, 0.7499999999999999] true_scores = [0.7500000774286983, 0.7500000516191412]
assert meter.pearson_r2() == true_scores assert meter.pearson_r2() == true_scores
assert meter.pearson_r2('mean') == np.mean(true_scores) assert meter.pearson_r2('mean') == np.mean(true_scores)
assert meter.pearson_r2('sum') == np.sum(true_scores) assert meter.pearson_r2('sum') == np.sum(true_scores)
...@@ -27,7 +27,7 @@ def test_Meter(): ...@@ -27,7 +27,7 @@ def test_Meter():
assert meter.compute_metric('r2', 'sum') == np.sum(true_scores) assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)
meter = Meter(label_mean, label_std) meter = Meter(label_mean, label_std)
meter.update(label, pred, mask) meter.update(pred, label, mask)
true_scores = [1.0, 1.0] true_scores = [1.0, 1.0]
assert meter.pearson_r2() == true_scores assert meter.pearson_r2() == true_scores
assert meter.pearson_r2('mean') == np.mean(true_scores) assert meter.pearson_r2('mean') == np.mean(true_scores)
...@@ -38,7 +38,7 @@ def test_Meter(): ...@@ -38,7 +38,7 @@ def test_Meter():
# mae # mae
meter = Meter() meter = Meter()
meter.update(label, pred) meter.update(pred, label)
true_scores = [0.1666666716337204, 0.1666666716337204] true_scores = [0.1666666716337204, 0.1666666716337204]
assert meter.mae() == true_scores assert meter.mae() == true_scores
assert meter.mae('mean') == np.mean(true_scores) assert meter.mae('mean') == np.mean(true_scores)
...@@ -48,7 +48,7 @@ def test_Meter(): ...@@ -48,7 +48,7 @@ def test_Meter():
assert meter.compute_metric('mae', 'sum') == np.sum(true_scores) assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)
meter = Meter() meter = Meter()
meter.update(label, pred, mask) meter.update(pred, label, mask)
true_scores = [0.25, 0.0] true_scores = [0.25, 0.0]
assert meter.mae() == true_scores assert meter.mae() == true_scores
assert meter.mae('mean') == np.mean(true_scores) assert meter.mae('mean') == np.mean(true_scores)
...@@ -57,23 +57,23 @@ def test_Meter(): ...@@ -57,23 +57,23 @@ def test_Meter():
assert meter.compute_metric('mae', 'mean') == np.mean(true_scores) assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
assert meter.compute_metric('mae', 'sum') == np.sum(true_scores) assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)
# rmse # rmse
meter = Meter(label_mean, label_std) meter = Meter(label_mean, label_std)
meter.update(label, pred) meter.update(pred, label)
true_scores = [0.22125875529784111, 0.5937311018897714] true_scores = [0.41068359261794546, 0.4106836107598449]
assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores)) assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores)) assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))
meter = Meter(label_mean, label_std) meter = Meter(label_mean, label_std)
meter.update(label, pred, mask) meter.update(pred, label, mask)
true_scores = [0.1337071188699867, 0.5019903799993205] true_scores = [0.44433766459035057, 0.5019903799993205]
assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores)) assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores)) assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))
# roc auc score # roc auc score
meter = Meter() meter = Meter()
meter.update(label, pred) meter.update(pred, label)
true_scores = [1.0, 0.75] true_scores = [1.0, 1.0]
assert meter.roc_auc_score() == true_scores assert meter.roc_auc_score() == true_scores
assert meter.roc_auc_score('mean') == np.mean(true_scores) assert meter.roc_auc_score('mean') == np.mean(true_scores)
assert meter.roc_auc_score('sum') == np.sum(true_scores) assert meter.roc_auc_score('sum') == np.sum(true_scores)
...@@ -82,7 +82,7 @@ def test_Meter(): ...@@ -82,7 +82,7 @@ def test_Meter():
assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores) assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)
meter = Meter() meter = Meter()
meter.update(label, pred, mask) meter.update(pred, label, mask)
true_scores = [1.0, 1.0] true_scores = [1.0, 1.0]
assert meter.roc_auc_score() == true_scores assert meter.roc_auc_score() == true_scores
assert meter.roc_auc_score('mean') == np.mean(true_scores) assert meter.roc_auc_score('mean') == np.mean(true_scores)
...@@ -91,5 +91,20 @@ def test_Meter(): ...@@ -91,5 +91,20 @@ def test_Meter():
assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores) assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores) assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)
def test_cases_with_undefined_scores():
    """Check Meter.roc_auc_score when a task's labels hold a single class.

    The second task's labels are all 1, so its ROC AUC is undefined and the
    Meter is expected to drop that task, reporting a score list for the
    remaining (well-defined) task only.
    """
    y_true = torch.tensor([[0., 1.],
                           [0., 1.],
                           [1., 1.]])
    y_pred = torch.tensor([[0.5, 0.5],
                           [0., 1.],
                           [1., 0.]])
    meter = Meter()
    meter.update(y_pred, y_true)
    expected = [1.0]
    assert meter.roc_auc_score() == expected
    assert meter.roc_auc_score('mean') == np.mean(expected)
    assert meter.roc_auc_score('sum') == np.sum(expected)
# Run the test suite directly when this file is executed as a script.
if __name__ == '__main__':
    test_Meter()
    test_cases_with_undefined_scores()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment