Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ed4134ed
Unverified
Commit
ed4134ed
authored
Jan 21, 2022
by
Zekuan (Kay) Liu
Committed by
GitHub
Jan 21, 2022
Browse files
[Example] fix auc in caregnn example (#3647)
Co-authored-by:
zhjwy9343
<
6593865@qq.com
>
parent
57476371
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
5 deletions
+7
-5
examples/pytorch/caregnn/main.py
examples/pytorch/caregnn/main.py
+4
-3
examples/pytorch/caregnn/main_sampling.py
examples/pytorch/caregnn/main_sampling.py
+3
-2
No files found.
examples/pytorch/caregnn/main.py
View file @
ed4134ed
...
@@ -3,6 +3,7 @@ import argparse
...
@@ -3,6 +3,7 @@ import argparse
import
torch
as
th
import
torch
as
th
from
model
import
CAREGNN
from
model
import
CAREGNN
import
torch.optim
as
optim
import
torch.optim
as
optim
from
torch.nn.functional
import
softmax
from
sklearn.metrics
import
recall_score
,
roc_auc_score
from
sklearn.metrics
import
recall_score
,
roc_auc_score
from
utils
import
EarlyStopping
from
utils
import
EarlyStopping
...
@@ -70,13 +71,13 @@ def main(args):
...
@@ -70,13 +71,13 @@ def main(args):
args
.
sim_weight
*
loss_fn
(
logits_sim
[
train_idx
],
labels
[
train_idx
])
args
.
sim_weight
*
loss_fn
(
logits_sim
[
train_idx
],
labels
[
train_idx
])
tr_recall
=
recall_score
(
labels
[
train_idx
].
cpu
(),
logits_gnn
.
data
[
train_idx
].
argmax
(
dim
=
1
).
cpu
())
tr_recall
=
recall_score
(
labels
[
train_idx
].
cpu
(),
logits_gnn
.
data
[
train_idx
].
argmax
(
dim
=
1
).
cpu
())
tr_auc
=
roc_auc_score
(
labels
[
train_idx
].
cpu
(),
logits_gnn
.
data
[
train_idx
][:,
1
].
cpu
())
tr_auc
=
roc_auc_score
(
labels
[
train_idx
].
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
.
data
[
train_idx
][:,
1
].
cpu
())
# validation
# validation
val_loss
=
loss_fn
(
logits_gnn
[
val_idx
],
labels
[
val_idx
])
+
\
val_loss
=
loss_fn
(
logits_gnn
[
val_idx
],
labels
[
val_idx
])
+
\
args
.
sim_weight
*
loss_fn
(
logits_sim
[
val_idx
],
labels
[
val_idx
])
args
.
sim_weight
*
loss_fn
(
logits_sim
[
val_idx
],
labels
[
val_idx
])
val_recall
=
recall_score
(
labels
[
val_idx
].
cpu
(),
logits_gnn
.
data
[
val_idx
].
argmax
(
dim
=
1
).
cpu
())
val_recall
=
recall_score
(
labels
[
val_idx
].
cpu
(),
logits_gnn
.
data
[
val_idx
].
argmax
(
dim
=
1
).
cpu
())
val_auc
=
roc_auc_score
(
labels
[
val_idx
].
cpu
(),
logits_gnn
.
data
[
val_idx
][:,
1
].
cpu
())
val_auc
=
roc_auc_score
(
labels
[
val_idx
].
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
.
data
[
val_idx
][:,
1
].
cpu
())
# backward
# backward
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
...
@@ -106,7 +107,7 @@ def main(args):
...
@@ -106,7 +107,7 @@ def main(args):
test_loss
=
loss_fn
(
logits_gnn
[
test_idx
],
labels
[
test_idx
])
+
\
test_loss
=
loss_fn
(
logits_gnn
[
test_idx
],
labels
[
test_idx
])
+
\
args
.
sim_weight
*
loss_fn
(
logits_sim
[
test_idx
],
labels
[
test_idx
])
args
.
sim_weight
*
loss_fn
(
logits_sim
[
test_idx
],
labels
[
test_idx
])
test_recall
=
recall_score
(
labels
[
test_idx
].
cpu
(),
logits_gnn
[
test_idx
].
argmax
(
dim
=
1
).
cpu
())
test_recall
=
recall_score
(
labels
[
test_idx
].
cpu
(),
logits_gnn
[
test_idx
].
argmax
(
dim
=
1
).
cpu
())
test_auc
=
roc_auc_score
(
labels
[
test_idx
].
cpu
(),
logits_gnn
.
data
[
test_idx
][:,
1
].
cpu
())
test_auc
=
roc_auc_score
(
labels
[
test_idx
].
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
.
data
[
test_idx
][:,
1
].
cpu
())
print
(
"Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}"
.
format
(
test_recall
,
test_auc
,
test_loss
.
item
()))
print
(
"Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}"
.
format
(
test_recall
,
test_auc
,
test_loss
.
item
()))
...
...
examples/pytorch/caregnn/main_sampling.py
View file @
ed4134ed
...
@@ -2,6 +2,7 @@ import dgl
...
@@ -2,6 +2,7 @@ import dgl
import
argparse
import
argparse
import
torch
as
th
import
torch
as
th
import
torch.optim
as
optim
import
torch.optim
as
optim
from
torch.nn.functional
import
softmax
from
sklearn.metrics
import
roc_auc_score
,
recall_score
from
sklearn.metrics
import
roc_auc_score
,
recall_score
from
utils
import
EarlyStopping
from
utils
import
EarlyStopping
...
@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
...
@@ -22,7 +23,7 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
# compute loss
# compute loss
loss
+=
loss_fn
(
logits_gnn
,
label
).
item
()
+
args
.
sim_weight
*
loss_fn
(
logits_sim
,
label
).
item
()
loss
+=
loss_fn
(
logits_gnn
,
label
).
item
()
+
args
.
sim_weight
*
loss_fn
(
logits_sim
,
label
).
item
()
recall
+=
recall_score
(
label
.
cpu
(),
logits_gnn
.
argmax
(
dim
=
1
).
detach
().
cpu
())
recall
+=
recall_score
(
label
.
cpu
(),
logits_gnn
.
argmax
(
dim
=
1
).
detach
().
cpu
())
auc
+=
roc_auc_score
(
label
.
cpu
(),
logits_gnn
[:,
1
].
detach
().
cpu
())
auc
+=
roc_auc_score
(
label
.
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
[:,
1
].
detach
().
cpu
())
num_blocks
+=
1
num_blocks
+=
1
return
recall
/
num_blocks
,
auc
/
num_blocks
,
loss
/
num_blocks
return
recall
/
num_blocks
,
auc
/
num_blocks
,
loss
/
num_blocks
...
@@ -121,7 +122,7 @@ def main(args):
...
@@ -121,7 +122,7 @@ def main(args):
blk_loss
=
loss_fn
(
logits_gnn
,
train_label
)
+
args
.
sim_weight
*
loss_fn
(
logits_sim
,
train_label
)
blk_loss
=
loss_fn
(
logits_gnn
,
train_label
)
+
args
.
sim_weight
*
loss_fn
(
logits_sim
,
train_label
)
tr_loss
+=
blk_loss
.
item
()
tr_loss
+=
blk_loss
.
item
()
tr_recall
+=
recall_score
(
train_label
.
cpu
(),
logits_gnn
.
argmax
(
dim
=
1
).
detach
().
cpu
())
tr_recall
+=
recall_score
(
train_label
.
cpu
(),
logits_gnn
.
argmax
(
dim
=
1
).
detach
().
cpu
())
tr_auc
+=
roc_auc_score
(
train_label
.
cpu
(),
logits_gnn
[:,
1
].
detach
().
cpu
())
tr_auc
+=
roc_auc_score
(
train_label
.
cpu
(),
softmax
(
logits_gnn
,
dim
=
1
)
[:,
1
].
detach
().
cpu
())
tr_blk
+=
1
tr_blk
+=
1
# backward
# backward
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment