Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ed4134ed
Unverified
Commit
ed4134ed
authored
Jan 21, 2022
by
Zekuan (Kay) Liu
Committed by
GitHub
Jan 21, 2022
Browse files
[Example] fix auc in caregnn example (#3647)
Co-authored-by:
zhjwy9343
<
6593865@qq.com
>
parent
57476371
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
5 deletions
+7
-5
examples/pytorch/caregnn/main.py
examples/pytorch/caregnn/main.py
+4
-3
examples/pytorch/caregnn/main_sampling.py
examples/pytorch/caregnn/main_sampling.py
+3
-2
No files found.
examples/pytorch/caregnn/main.py
View file @
ed4134ed
...
...
@@ -3,6 +3,7 @@ import argparse
import
torch
as
th
from
model
import
CAREGNN
import
torch.optim
as
optim
from
torch.nn.functional
import
softmax
from
sklearn.metrics
import
recall_score
,
roc_auc_score
from
utils
import
EarlyStopping
...
...
@@ -70,13 +71,13 @@ def main(args):
args
.
sim_weight
*
loss_fn
(
logits_sim
[
train_idx
],
labels
[
train_idx
])
tr_recall
=
recall_score
(
labels
[
train_idx
].
cpu
(),
logits_gnn
.
data
[
train_idx
].
argmax
(
dim
=
1
).
cpu
())
tr_auc
=
roc_auc_score
(
labels
[
train_idx
].
cpu
(),
logits_gnn
.
data
[
train_idx
][:,
1
].
cpu
())
tr_auc
=
roc_auc_score
(
labels
[
train_idx
].
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
.
data
[
train_idx
][:,
1
].
cpu
())
# validation
val_loss
=
loss_fn
(
logits_gnn
[
val_idx
],
labels
[
val_idx
])
+
\
args
.
sim_weight
*
loss_fn
(
logits_sim
[
val_idx
],
labels
[
val_idx
])
val_recall
=
recall_score
(
labels
[
val_idx
].
cpu
(),
logits_gnn
.
data
[
val_idx
].
argmax
(
dim
=
1
).
cpu
())
val_auc
=
roc_auc_score
(
labels
[
val_idx
].
cpu
(),
logits_gnn
.
data
[
val_idx
][:,
1
].
cpu
())
val_auc
=
roc_auc_score
(
labels
[
val_idx
].
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
.
data
[
val_idx
][:,
1
].
cpu
())
# backward
optimizer
.
zero_grad
()
...
...
@@ -106,7 +107,7 @@ def main(args):
test_loss
=
loss_fn
(
logits_gnn
[
test_idx
],
labels
[
test_idx
])
+
\
args
.
sim_weight
*
loss_fn
(
logits_sim
[
test_idx
],
labels
[
test_idx
])
test_recall
=
recall_score
(
labels
[
test_idx
].
cpu
(),
logits_gnn
[
test_idx
].
argmax
(
dim
=
1
).
cpu
())
test_auc
=
roc_auc_score
(
labels
[
test_idx
].
cpu
(),
logits_gnn
.
data
[
test_idx
][:,
1
].
cpu
())
test_auc
=
roc_auc_score
(
labels
[
test_idx
].
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
.
data
[
test_idx
][:,
1
].
cpu
())
print
(
"Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}"
.
format
(
test_recall
,
test_auc
,
test_loss
.
item
()))
...
...
examples/pytorch/caregnn/main_sampling.py
View file @
ed4134ed
...
...
@@ -2,6 +2,7 @@ import dgl
import
argparse
import
torch
as
th
import
torch.optim
as
optim
from
torch.nn.functional
import
softmax
from
sklearn.metrics
import
roc_auc_score
,
recall_score
from
utils
import
EarlyStopping
...
...
@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
# compute loss
loss
+=
loss_fn
(
logits_gnn
,
label
).
item
()
+
args
.
sim_weight
*
loss_fn
(
logits_sim
,
label
).
item
()
recall
+=
recall_score
(
label
.
cpu
(),
logits_gnn
.
argmax
(
dim
=
1
).
detach
().
cpu
())
auc
+=
roc_auc_score
(
label
.
cpu
(),
logits_gnn
[:,
1
].
detach
().
cpu
())
auc
+=
roc_auc_score
(
label
.
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
[:,
1
].
detach
().
cpu
())
num_blocks
+=
1
return
recall
/
num_blocks
,
auc
/
num_blocks
,
loss
/
num_blocks
...
...
@@ -121,7 +122,7 @@ def main(args):
blk_loss
=
loss_fn
(
logits_gnn
,
train_label
)
+
args
.
sim_weight
*
loss_fn
(
logits_sim
,
train_label
)
tr_loss
+=
blk_loss
.
item
()
tr_recall
+=
recall_score
(
train_label
.
cpu
(),
logits_gnn
.
argmax
(
dim
=
1
).
detach
().
cpu
())
tr_auc
+=
roc_auc_score
(
train_label
.
cpu
(),
logits_gnn
[:,
1
].
detach
().
cpu
())
tr_auc
+=
roc_auc_score
(
train_label
.
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
[:,
1
].
detach
().
cpu
())
tr_blk
+=
1
# backward
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment