"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5588725e8e7be497839432e5328c596169385f16"
Unverified commit ed4134ed authored by Zekuan (Kay) Liu, committed by GitHub
Browse files

[Example] fix auc in caregnn example (#3647)


Co-authored-by: zhjwy9343 <6593865@qq.com>
parent 57476371
......@@ -3,6 +3,7 @@ import argparse
import torch as th
from model import CAREGNN
import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import recall_score, roc_auc_score
from utils import EarlyStopping
......@@ -70,13 +71,13 @@ def main(args):
args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])
tr_recall = recall_score(labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), logits_gnn.data[train_idx][:, 1].cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu())
# validation
val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + \
args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])
val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu())
val_auc = roc_auc_score(labels[val_idx].cpu(), logits_gnn.data[val_idx][:, 1].cpu())
val_auc = roc_auc_score(labels[val_idx].cpu(), softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu())
# backward
optimizer.zero_grad()
......@@ -106,7 +107,7 @@ def main(args):
test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + \
args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])
test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), logits_gnn.data[test_idx][:, 1].cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu())
print("Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(test_recall, test_auc, test_loss.item()))
......
......@@ -2,6 +2,7 @@ import dgl
import argparse
import torch as th
import torch.optim as optim
from torch.nn.functional import softmax
from sklearn.metrics import roc_auc_score, recall_score
from utils import EarlyStopping
......@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
# compute loss
loss += loss_fn(logits_gnn, label).item() + args.sim_weight * loss_fn(logits_sim, label).item()
recall += recall_score(label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
auc += roc_auc_score(label.cpu(), logits_gnn[:, 1].detach().cpu())
auc += roc_auc_score(label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
num_blocks += 1
return recall / num_blocks, auc / num_blocks, loss / num_blocks
......@@ -121,7 +122,7 @@ def main(args):
blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label)
tr_loss += blk_loss.item()
tr_recall += recall_score(train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
tr_auc += roc_auc_score(train_label.cpu(), logits_gnn[:, 1].detach().cpu())
tr_auc += roc_auc_score(train_label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu())
tr_blk += 1
# backward
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment