Unverified Commit ed4134ed authored by Zekuan (Kay) Liu's avatar Zekuan (Kay) Liu Committed by GitHub
Browse files

[Example] fix auc in caregnn example (#3647)


Co-authored-by: default avatarzhjwy9343 <6593865@qq.com>
parent 57476371
...@@ -3,6 +3,7 @@ import argparse ...@@ -3,6 +3,7 @@ import argparse
import torch as th import torch as th
from model import CAREGNN from model import CAREGNN
import torch.optim as optim import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import recall_score, roc_auc_score from sklearn.metrics import recall_score, roc_auc_score
from utils import EarlyStopping from utils import EarlyStopping
...@@ -70,13 +71,13 @@ def main(args): ...@@ -70,13 +71,13 @@ def main(args):
args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx]) args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])
tr_recall = recall_score(labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu()) tr_recall = recall_score(labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), logits_gnn.data[train_idx][:, 1].cpu()) tr_auc = roc_auc_score(labels[train_idx].cpu(), softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu())
# validation # validation
val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + \ val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + \
args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx]) args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])
val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu()) val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu())
val_auc = roc_auc_score(labels[val_idx].cpu(), logits_gnn.data[val_idx][:, 1].cpu()) val_auc = roc_auc_score(labels[val_idx].cpu(), softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu())
# backward # backward
optimizer.zero_grad() optimizer.zero_grad()
...@@ -106,7 +107,7 @@ def main(args): ...@@ -106,7 +107,7 @@ def main(args):
test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + \ test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + \
args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx]) args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])
test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu()) test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), logits_gnn.data[test_idx][:, 1].cpu()) test_auc = roc_auc_score(labels[test_idx].cpu(), softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu())
print("Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(test_recall, test_auc, test_loss.item())) print("Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(test_recall, test_auc, test_loss.item()))
......
...@@ -2,6 +2,7 @@ import dgl ...@@ -2,6 +2,7 @@ import dgl
import argparse import argparse
import torch as th import torch as th
import torch.optim as optim import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import roc_auc_score, recall_score from sklearn.metrics import roc_auc_score, recall_score
from utils import EarlyStopping from utils import EarlyStopping
...@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'): ...@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
# compute loss # compute loss
loss += loss_fn(logits_gnn, label).item() + args.sim_weight * loss_fn(logits_sim, label).item() loss += loss_fn(logits_gnn, label).item() + args.sim_weight * loss_fn(logits_sim, label).item()
recall += recall_score(label.cpu(), logits_gnn.argmax(dim=1).detach().cpu()) recall += recall_score(label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
auc += roc_auc_score(label.cpu(), logits_gnn[:, 1].detach().cpu()) auc += roc_auc_score(label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
num_blocks += 1 num_blocks += 1
return recall / num_blocks, auc / num_blocks, loss / num_blocks return recall / num_blocks, auc / num_blocks, loss / num_blocks
...@@ -121,7 +122,7 @@ def main(args): ...@@ -121,7 +122,7 @@ def main(args):
blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label) blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label)
tr_loss += blk_loss.item() tr_loss += blk_loss.item()
tr_recall += recall_score(train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu()) tr_recall += recall_score(train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
tr_auc += roc_auc_score(train_label.cpu(), logits_gnn[:, 1].detach().cpu()) tr_auc += roc_auc_score(train_label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
tr_blk += 1 tr_blk += 1
# backward # backward
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment