Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3e26c3d1
Unverified
Commit
3e26c3d1
authored
Jul 08, 2022
by
Xin Yao
Committed by
GitHub
Jul 08, 2022
Browse files
fix caregnn (#4211)
Co-authored-by:
Mufei Li
<
mufeili1996@gmail.com
>
parent
52d43127
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
27 deletions
+27
-27
examples/pytorch/caregnn/main_sampling.py
examples/pytorch/caregnn/main_sampling.py
+11
-25
examples/pytorch/caregnn/model_sampling.py
examples/pytorch/caregnn/model_sampling.py
+16
-2
No files found.
examples/pytorch/caregnn/main_sampling.py
View file @
3e26c3d1
...
@@ -92,9 +92,10 @@ def main(args):
...
@@ -92,9 +92,10 @@ def main(args):
graph
.
ndata
[
'nd'
]
=
th
.
tanh
(
model
.
layers
[
i
].
MLP
(
layers_feat
[
i
]))
graph
.
ndata
[
'nd'
]
=
th
.
tanh
(
model
.
layers
[
i
].
MLP
(
layers_feat
[
i
]))
for
etype
in
graph
.
canonical_etypes
:
for
etype
in
graph
.
canonical_etypes
:
graph
.
apply_edges
(
_l1_dist
,
etype
=
etype
)
graph
.
apply_edges
(
_l1_dist
,
etype
=
etype
)
dist
[
etype
]
=
graph
.
edges
[
etype
].
data
[
'ed'
]
dist
[
etype
]
=
graph
.
edges
[
etype
].
data
.
pop
(
'ed'
).
detach
().
cpu
()
dists
.
append
(
dist
)
dists
.
append
(
dist
)
p
.
append
(
model
.
layers
[
i
].
p
)
p
.
append
(
model
.
layers
[
i
].
p
)
graph
.
ndata
.
pop
(
'nd'
)
sampler
=
CARESampler
(
p
,
dists
,
args
.
num_layers
)
sampler
=
CARESampler
(
p
,
dists
,
args
.
num_layers
)
# train
# train
...
@@ -103,14 +104,9 @@ def main(args):
...
@@ -103,14 +104,9 @@ def main(args):
tr_recall
=
0
tr_recall
=
0
tr_auc
=
0
tr_auc
=
0
tr_blk
=
0
tr_blk
=
0
train_dataloader
=
dgl
.
dataloading
.
DataLoader
(
graph
,
train_dataloader
=
dgl
.
dataloading
.
DataLoader
(
train_idx
,
graph
,
train_idx
,
sampler
,
batch_size
=
args
.
batch_size
,
sampler
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
args
.
num_workers
)
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
args
.
num_workers
)
for
input_nodes
,
output_nodes
,
blocks
in
train_dataloader
:
for
input_nodes
,
output_nodes
,
blocks
in
train_dataloader
:
blocks
=
[
b
.
to
(
device
)
for
b
in
blocks
]
blocks
=
[
b
.
to
(
device
)
for
b
in
blocks
]
...
@@ -135,14 +131,9 @@ def main(args):
...
@@ -135,14 +131,9 @@ def main(args):
# validation
# validation
model
.
eval
()
model
.
eval
()
val_dataloader
=
dgl
.
dataloading
.
DataLoader
(
graph
,
val_dataloader
=
dgl
.
dataloading
.
DataLoader
(
val_idx
,
graph
,
val_idx
,
sampler
,
batch_size
=
args
.
batch_size
,
sampler
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
args
.
num_workers
)
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
args
.
num_workers
)
val_recall
,
val_auc
,
val_loss
=
evaluate
(
model
,
loss_fn
,
val_dataloader
,
device
)
val_recall
,
val_auc
,
val_loss
=
evaluate
(
model
,
loss_fn
,
val_dataloader
,
device
)
...
@@ -159,14 +150,9 @@ def main(args):
...
@@ -159,14 +150,9 @@ def main(args):
model
.
eval
()
model
.
eval
()
if
args
.
early_stop
:
if
args
.
early_stop
:
model
.
load_state_dict
(
th
.
load
(
'es_checkpoint.pt'
))
model
.
load_state_dict
(
th
.
load
(
'es_checkpoint.pt'
))
test_dataloader
=
dgl
.
dataloading
.
DataLoader
(
graph
,
test_dataloader
=
dgl
.
dataloading
.
DataLoader
(
test_idx
,
graph
,
test_idx
,
sampler
,
batch_size
=
args
.
batch_size
,
sampler
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
args
.
num_workers
)
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
args
.
num_workers
)
test_recall
,
test_auc
,
test_loss
=
evaluate
(
model
,
loss_fn
,
test_dataloader
,
device
)
test_recall
,
test_auc
,
test_loss
=
evaluate
(
model
,
loss_fn
,
test_dataloader
,
device
)
...
...
examples/pytorch/caregnn/model_sampling.py
View file @
3e26c3d1
...
@@ -13,9 +13,10 @@ def _l1_dist(edges):
...
@@ -13,9 +13,10 @@ def _l1_dist(edges):
class
CARESampler
(
dgl
.
dataloading
.
BlockSampler
):
class
CARESampler
(
dgl
.
dataloading
.
BlockSampler
):
def __init__(self, p, dists, num_layers):
    """Create a CARE-GNN neighbor sampler.

    Parameters
    ----------
    p : list[dict]
        Per-layer keep probabilities keyed by canonical edge type;
        indexed as ``p[block_id][etype]`` during frontier sampling.
    dists : list[dict]
        Per-layer label-aware distances per edge type, used to rank
        neighbors when sub-sampling.
    num_layers : int
        Number of GNN layers, i.e. number of blocks to sample.
    """
    # NOTE(review): the base class is initialized without a depth
    # argument; the layer count is kept on the instance and consumed
    # by sample_blocks()/__len__.
    super().__init__()
    self.p = p
    self.dists = dists
    self.num_layers = num_layers
def
sample_frontier
(
self
,
block_id
,
g
,
seed_nodes
,
*
args
,
**
kwargs
):
def
sample_frontier
(
self
,
block_id
,
g
,
seed_nodes
,
*
args
,
**
kwargs
):
with
g
.
local_scope
():
with
g
.
local_scope
():
...
@@ -28,7 +29,7 @@ class CARESampler(dgl.dataloading.BlockSampler):
...
@@ -28,7 +29,7 @@ class CARESampler(dgl.dataloading.BlockSampler):
num_neigh
=
th
.
ceil
(
g
.
in_degrees
(
node
,
etype
=
etype
)
*
self
.
p
[
block_id
][
etype
]).
int
().
item
()
num_neigh
=
th
.
ceil
(
g
.
in_degrees
(
node
,
etype
=
etype
)
*
self
.
p
[
block_id
][
etype
]).
int
().
item
()
neigh_dist
=
self
.
dists
[
block_id
][
etype
][
edges
]
neigh_dist
=
self
.
dists
[
block_id
][
etype
][
edges
]
if
neigh_dist
.
shape
[
0
]
>
num_neigh
:
if
neigh_dist
.
shape
[
0
]
>
num_neigh
:
neigh_index
=
np
.
argpartition
(
neigh_dist
.
cpu
().
detach
()
,
num_neigh
)[:
num_neigh
]
neigh_index
=
np
.
argpartition
(
neigh_dist
,
num_neigh
)[:
num_neigh
]
else
:
else
:
neigh_index
=
np
.
arange
(
num_neigh
)
neigh_index
=
np
.
arange
(
num_neigh
)
edge_mask
[
edges
[
neigh_index
]]
=
1
edge_mask
[
edges
[
neigh_index
]]
=
1
...
@@ -36,6 +37,19 @@ class CARESampler(dgl.dataloading.BlockSampler):
...
@@ -36,6 +37,19 @@ class CARESampler(dgl.dataloading.BlockSampler):
return
dgl
.
edge_subgraph
(
g
,
new_edges_masks
,
relabel_nodes
=
False
)
return
dgl
.
edge_subgraph
(
g
,
new_edges_masks
,
relabel_nodes
=
False
)
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
    """Sample the message-passing blocks for one mini-batch.

    Walks the layers from the output side inward: for each layer a
    frontier subgraph is drawn from the full graph, converted to a
    block, and prepended, so the returned list is ordered from the
    input layer to the output layer.

    Parameters
    ----------
    g : DGLGraph
        The full graph to sample from.
    seed_nodes : Tensor or dict[str, Tensor]
        Output (seed) nodes of the batch.
    exclude_eids
        Unused; kept for interface compatibility with DGL dataloading.

    Returns
    -------
    tuple
        ``(input_nodes, output_nodes, blocks)``.
    """
    output_nodes = seed_nodes
    blocks = []
    for block_id in reversed(range(self.num_layers)):
        frontier = self.sample_frontier(block_id, g, seed_nodes)
        # dgl.to_block() relabels edges, so stash the frontier's
        # original edge IDs and restore them on the block.
        eid = frontier.edata[dgl.EID]
        block = dgl.to_block(frontier, seed_nodes)
        block.edata[dgl.EID] = eid
        # Source nodes of this block become the seeds for the next
        # (shallower) layer.
        seed_nodes = block.srcdata[dgl.NID]
        blocks.insert(0, block)
    return seed_nodes, output_nodes, blocks
def __len__(self):
    """Return the number of sampling layers (one block per layer)."""
    return self.num_layers
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment