Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9836f78e
Unverified
Commit
9836f78e
authored
Feb 19, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 19, 2023
Browse files
autoformat (#5322)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
704bcaf6
Changes
57
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
201 additions
and
142 deletions
+201
-142
benchmarks/benchmarks/api/bench_nn_graphconv.py
benchmarks/benchmarks/api/bench_nn_graphconv.py
+3
-3
benchmarks/benchmarks/api/bench_nn_heterographconv.py
benchmarks/benchmarks/api/bench_nn_heterographconv.py
+3
-3
benchmarks/benchmarks/api/bench_node_subgraph.py
benchmarks/benchmarks/api/bench_node_subgraph.py
+3
-3
benchmarks/benchmarks/api/bench_random_walk.py
benchmarks/benchmarks/api/bench_random_walk.py
+2
-2
benchmarks/benchmarks/api/bench_readout.py
benchmarks/benchmarks/api/bench_readout.py
+2
-2
benchmarks/benchmarks/api/bench_reverse.py
benchmarks/benchmarks/api/bench_reverse.py
+2
-2
benchmarks/benchmarks/api/bench_sample_neighbors.py
benchmarks/benchmarks/api/bench_sample_neighbors.py
+3
-3
benchmarks/benchmarks/api/bench_to_block.py
benchmarks/benchmarks/api/bench_to_block.py
+2
-2
benchmarks/benchmarks/api/bench_udf_apply_edges.py
benchmarks/benchmarks/api/bench_udf_apply_edges.py
+3
-3
benchmarks/benchmarks/api/bench_udf_multi_update_all.py
benchmarks/benchmarks/api/bench_udf_multi_update_all.py
+3
-3
benchmarks/benchmarks/api/bench_udf_update_all.py
benchmarks/benchmarks/api/bench_udf_update_all.py
+3
-3
benchmarks/benchmarks/api/bench_unbatch.py
benchmarks/benchmarks/api/bench_unbatch.py
+2
-2
benchmarks/benchmarks/kernel/bench_edgesoftmax.py
benchmarks/benchmarks/kernel/bench_edgesoftmax.py
+2
-2
benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py
benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py
+2
-2
benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py
benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py
+2
-2
benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py
benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py
+2
-2
benchmarks/benchmarks/model_acc/bench_gat.py
benchmarks/benchmarks/model_acc/bench_gat.py
+1
-2
benchmarks/benchmarks/model_acc/bench_gcn.py
benchmarks/benchmarks/model_acc/bench_gcn.py
+1
-2
benchmarks/benchmarks/model_acc/bench_gcn_udf.py
benchmarks/benchmarks/model_acc/bench_gcn_udf.py
+1
-2
benchmarks/benchmarks/model_acc/bench_rgcn_ns.py
benchmarks/benchmarks/model_acc/bench_rgcn_ns.py
+159
-97
No files found.
benchmarks/benchmarks/api/bench_nn_graphconv.py
View file @
9836f78e
import
time
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.nn.pytorch
import
SAGEConv
from
dgl.nn.pytorch
import
SAGEConv
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_nn_heterographconv.py
View file @
9836f78e
import
time
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.nn.pytorch
import
HeteroGraphConv
,
SAGEConv
from
dgl.nn.pytorch
import
HeteroGraphConv
,
SAGEConv
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_node_subgraph.py
View file @
9836f78e
import
time
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl
import
dgl.function
as
fn
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_random_walk.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_readout.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_reverse.py
View file @
9836f78e
import
time
import
time
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
dgl
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_sample_neighbors.py
View file @
9836f78e
import
time
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl
import
dgl.function
as
fn
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_to_block.py
View file @
9836f78e
import
time
import
time
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
dgl
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_udf_apply_edges.py
View file @
9836f78e
import
time
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl
import
dgl.function
as
fn
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_udf_multi_update_all.py
View file @
9836f78e
import
time
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl
import
dgl.function
as
fn
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_udf_update_all.py
View file @
9836f78e
import
time
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl
import
dgl.function
as
fn
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_unbatch.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_edgesoftmax.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py
View file @
9836f78e
import
time
import
time
import
torch
import
dgl
import
dgl
import
torch
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_gat.py
View file @
9836f78e
import
dgl
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
from
dgl.nn.pytorch
import
GATConv
from
dgl.nn.pytorch
import
GATConv
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_gcn.py
View file @
9836f78e
import
dgl
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
from
dgl.nn.pytorch
import
GraphConv
from
dgl.nn.pytorch
import
GraphConv
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_gcn_udf.py
View file @
9836f78e
import
dgl
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
from
..
import
utils
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_rgcn_ns.py
View file @
9836f78e
import
dgl
import
itertools
import
itertools
import
time
import
dgl
import
dgl.nn.pytorch
as
dglnn
import
torch
as
th
import
torch
as
th
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.optim
as
optim
import
torch.multiprocessing
as
mp
from
torch.utils.data
import
DataLoader
import
dgl.nn.pytorch
as
dglnn
from
dgl.nn
import
RelGraphConv
from
dgl.nn
import
RelGraphConv
import
time
from
torch.utils.data
import
DataLoader
from
..
import
utils
from
..
import
utils
class
EntityClassify
(
nn
.
Module
):
class
EntityClassify
(
nn
.
Module
):
"""
Entity classification class for RGCN
"""Entity classification class for RGCN
Parameters
Parameters
----------
----------
device : int
device : int
...
@@ -35,7 +37,9 @@ class EntityClassify(nn.Module):
...
@@ -35,7 +37,9 @@ class EntityClassify(nn.Module):
use_self_loop : bool
use_self_loop : bool
Use self loop if True, default False.
Use self loop if True, default False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
device
,
device
,
num_nodes
,
num_nodes
,
h_dim
,
h_dim
,
...
@@ -45,7 +49,8 @@ class EntityClassify(nn.Module):
...
@@ -45,7 +49,8 @@ class EntityClassify(nn.Module):
num_hidden_layers
=
1
,
num_hidden_layers
=
1
,
dropout
=
0
,
dropout
=
0
,
use_self_loop
=
False
,
use_self_loop
=
False
,
layer_norm
=
False
):
layer_norm
=
False
,
):
super
(
EntityClassify
,
self
).
__init__
()
super
(
EntityClassify
,
self
).
__init__
()
self
.
device
=
device
self
.
device
=
device
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
...
@@ -60,22 +65,47 @@ class EntityClassify(nn.Module):
...
@@ -60,22 +65,47 @@ class EntityClassify(nn.Module):
self
.
layers
=
nn
.
ModuleList
()
self
.
layers
=
nn
.
ModuleList
()
# i2h
# i2h
self
.
layers
.
append
(
RelGraphConv
(
self
.
layers
.
append
(
self
.
h_dim
,
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
RelGraphConv
(
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
self
.
h_dim
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
))
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
,
)
)
# h2h
# h2h
for
idx
in
range
(
self
.
num_hidden_layers
):
for
idx
in
range
(
self
.
num_hidden_layers
):
self
.
layers
.
append
(
RelGraphConv
(
self
.
layers
.
append
(
self
.
h_dim
,
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
RelGraphConv
(
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
self
.
h_dim
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
))
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
,
)
)
# h2o
# h2o
self
.
layers
.
append
(
RelGraphConv
(
self
.
layers
.
append
(
self
.
h_dim
,
self
.
out_dim
,
self
.
num_rels
,
"basis"
,
RelGraphConv
(
self
.
num_bases
,
activation
=
None
,
self
.
h_dim
,
self
.
out_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
None
,
self_loop
=
self
.
use_self_loop
,
self_loop
=
self
.
use_self_loop
,
layer_norm
=
layer_norm
))
layer_norm
=
layer_norm
,
)
)
def
forward
(
self
,
blocks
,
feats
,
norm
=
None
):
def
forward
(
self
,
blocks
,
feats
,
norm
=
None
):
if
blocks
is
None
:
if
blocks
is
None
:
...
@@ -84,9 +114,10 @@ class EntityClassify(nn.Module):
...
@@ -84,9 +114,10 @@ class EntityClassify(nn.Module):
h
=
feats
h
=
feats
for
layer
,
block
in
zip
(
self
.
layers
,
blocks
):
for
layer
,
block
in
zip
(
self
.
layers
,
blocks
):
block
=
block
.
to
(
self
.
device
)
block
=
block
.
to
(
self
.
device
)
h
=
layer
(
block
,
h
,
block
.
edata
[
'
etype
'
],
block
.
edata
[
'
norm
'
])
h
=
layer
(
block
,
h
,
block
.
edata
[
"
etype
"
],
block
.
edata
[
"
norm
"
])
return
h
return
h
class
RelGraphEmbedLayer
(
nn
.
Module
):
class
RelGraphEmbedLayer
(
nn
.
Module
):
r
"""Embedding layer for featureless heterograph.
r
"""Embedding layer for featureless heterograph.
Parameters
Parameters
...
@@ -107,7 +138,9 @@ class RelGraphEmbedLayer(nn.Module):
...
@@ -107,7 +138,9 @@ class RelGraphEmbedLayer(nn.Module):
embed_name : str, optional
embed_name : str, optional
Embed name
Embed name
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
device
,
device
,
num_nodes
,
num_nodes
,
node_tids
,
node_tids
,
...
@@ -115,7 +148,8 @@ class RelGraphEmbedLayer(nn.Module):
...
@@ -115,7 +148,8 @@ class RelGraphEmbedLayer(nn.Module):
input_size
,
input_size
,
embed_size
,
embed_size
,
sparse_emb
=
False
,
sparse_emb
=
False
,
embed_name
=
'embed'
):
embed_name
=
"embed"
,
):
super
(
RelGraphEmbedLayer
,
self
).
__init__
()
super
(
RelGraphEmbedLayer
,
self
).
__init__
()
self
.
device
=
device
self
.
device
=
device
self
.
embed_size
=
embed_size
self
.
embed_size
=
embed_size
...
@@ -135,7 +169,9 @@ class RelGraphEmbedLayer(nn.Module):
...
@@ -135,7 +169,9 @@ class RelGraphEmbedLayer(nn.Module):
nn
.
init
.
xavier_uniform_
(
embed
)
nn
.
init
.
xavier_uniform_
(
embed
)
self
.
embeds
[
str
(
ntype
)]
=
embed
self
.
embeds
[
str
(
ntype
)]
=
embed
self
.
node_embeds
=
th
.
nn
.
Embedding
(
node_tids
.
shape
[
0
],
self
.
embed_size
,
sparse
=
self
.
sparse_emb
)
self
.
node_embeds
=
th
.
nn
.
Embedding
(
node_tids
.
shape
[
0
],
self
.
embed_size
,
sparse
=
self
.
sparse_emb
)
nn
.
init
.
uniform_
(
self
.
node_embeds
.
weight
,
-
1.0
,
1.0
)
nn
.
init
.
uniform_
(
self
.
node_embeds
.
weight
,
-
1.0
,
1.0
)
def
forward
(
self
,
node_ids
,
node_tids
,
type_ids
,
features
):
def
forward
(
self
,
node_ids
,
node_tids
,
type_ids
,
features
):
...
@@ -157,17 +193,22 @@ class RelGraphEmbedLayer(nn.Module):
...
@@ -157,17 +193,22 @@ class RelGraphEmbedLayer(nn.Module):
embeddings as the input of the next layer
embeddings as the input of the next layer
"""
"""
tsd_ids
=
node_ids
.
to
(
self
.
node_embeds
.
weight
.
device
)
tsd_ids
=
node_ids
.
to
(
self
.
node_embeds
.
weight
.
device
)
embeds
=
th
.
empty
(
node_ids
.
shape
[
0
],
self
.
embed_size
,
device
=
self
.
device
)
embeds
=
th
.
empty
(
node_ids
.
shape
[
0
],
self
.
embed_size
,
device
=
self
.
device
)
for
ntype
in
range
(
self
.
num_of_ntype
):
for
ntype
in
range
(
self
.
num_of_ntype
):
if
features
[
ntype
]
is
not
None
:
if
features
[
ntype
]
is
not
None
:
loc
=
node_tids
==
ntype
loc
=
node_tids
==
ntype
embeds
[
loc
]
=
features
[
ntype
][
type_ids
[
loc
]].
to
(
self
.
device
)
@
self
.
embeds
[
str
(
ntype
)].
to
(
self
.
device
)
embeds
[
loc
]
=
features
[
ntype
][
type_ids
[
loc
]].
to
(
self
.
device
)
@
self
.
embeds
[
str
(
ntype
)].
to
(
self
.
device
)
else
:
else
:
loc
=
node_tids
==
ntype
loc
=
node_tids
==
ntype
embeds
[
loc
]
=
self
.
node_embeds
(
tsd_ids
[
loc
]).
to
(
self
.
device
)
embeds
[
loc
]
=
self
.
node_embeds
(
tsd_ids
[
loc
]).
to
(
self
.
device
)
return
embeds
return
embeds
def
evaluate
(
model
,
embed_layer
,
eval_loader
,
node_feats
):
def
evaluate
(
model
,
embed_layer
,
eval_loader
,
node_feats
):
model
.
eval
()
model
.
eval
()
embed_layer
.
eval
()
embed_layer
.
eval
()
...
@@ -178,36 +219,39 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
...
@@ -178,36 +219,39 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
for
sample_data
in
eval_loader
:
for
sample_data
in
eval_loader
:
th
.
cuda
.
empty_cache
()
th
.
cuda
.
empty_cache
()
_
,
_
,
blocks
=
sample_data
_
,
_
,
blocks
=
sample_data
feats
=
embed_layer
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
feats
=
embed_layer
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
blocks
[
0
].
srcdata
[
dgl
.
NTYPE
],
blocks
[
0
].
srcdata
[
dgl
.
NTYPE
],
blocks
[
0
].
srcdata
[
'type_id'
],
blocks
[
0
].
srcdata
[
"type_id"
],
node_feats
)
node_feats
,
)
logits
=
model
(
blocks
,
feats
)
logits
=
model
(
blocks
,
feats
)
eval_logits
.
append
(
logits
.
cpu
().
detach
())
eval_logits
.
append
(
logits
.
cpu
().
detach
())
eval_seeds
.
append
(
blocks
[
-
1
].
dstdata
[
'
type_id
'
].
cpu
().
detach
())
eval_seeds
.
append
(
blocks
[
-
1
].
dstdata
[
"
type_id
"
].
cpu
().
detach
())
eval_logits
=
th
.
cat
(
eval_logits
)
eval_logits
=
th
.
cat
(
eval_logits
)
eval_seeds
=
th
.
cat
(
eval_seeds
)
eval_seeds
=
th
.
cat
(
eval_seeds
)
return
eval_logits
,
eval_seeds
return
eval_logits
,
eval_seeds
@
utils
.
benchmark
(
'acc'
,
timeout
=
3600
)
# ogbn-mag takes ~1 hour to train
@
utils
.
parametrize
(
'data'
,
[
'am'
,
'ogbn-mag'
])
@
utils
.
benchmark
(
"acc"
,
timeout
=
3600
)
# ogbn-mag takes ~1 hour to train
@
utils
.
parametrize
(
"data"
,
[
"am"
,
"ogbn-mag"
])
def
track_acc
(
data
):
def
track_acc
(
data
):
dataset
=
utils
.
process_data
(
data
)
dataset
=
utils
.
process_data
(
data
)
device
=
utils
.
get_bench_device
()
device
=
utils
.
get_bench_device
()
if
data
==
'
am
'
:
if
data
==
"
am
"
:
n_bases
=
40
n_bases
=
40
l2norm
=
5e-4
l2norm
=
5e-4
n_epochs
=
20
n_epochs
=
20
elif
data
==
'
ogbn-mag
'
:
elif
data
==
"
ogbn-mag
"
:
n_bases
=
2
n_bases
=
2
l2norm
=
0
l2norm
=
0
n_epochs
=
20
n_epochs
=
20
else
:
else
:
raise
ValueError
()
raise
ValueError
()
fanouts
=
[
25
,
15
]
fanouts
=
[
25
,
15
]
n_layers
=
2
n_layers
=
2
batch_size
=
1024
batch_size
=
1024
n_hidden
=
64
n_hidden
=
64
...
@@ -219,20 +263,20 @@ def track_acc(data):
...
@@ -219,20 +263,20 @@ def track_acc(data):
hg
=
dataset
[
0
]
hg
=
dataset
[
0
]
category
=
dataset
.
predict_category
category
=
dataset
.
predict_category
num_classes
=
dataset
.
num_classes
num_classes
=
dataset
.
num_classes
train_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
'
train_mask
'
)
train_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
"
train_mask
"
)
train_idx
=
th
.
nonzero
(
train_mask
,
as_tuple
=
False
).
squeeze
()
train_idx
=
th
.
nonzero
(
train_mask
,
as_tuple
=
False
).
squeeze
()
test_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
'
test_mask
'
)
test_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
"
test_mask
"
)
test_idx
=
th
.
nonzero
(
test_mask
,
as_tuple
=
False
).
squeeze
()
test_idx
=
th
.
nonzero
(
test_mask
,
as_tuple
=
False
).
squeeze
()
labels
=
hg
.
nodes
[
category
].
data
.
pop
(
'
labels
'
).
to
(
device
)
labels
=
hg
.
nodes
[
category
].
data
.
pop
(
"
labels
"
).
to
(
device
)
num_of_ntype
=
len
(
hg
.
ntypes
)
num_of_ntype
=
len
(
hg
.
ntypes
)
num_rels
=
len
(
hg
.
canonical_etypes
)
num_rels
=
len
(
hg
.
canonical_etypes
)
node_feats
=
[]
node_feats
=
[]
for
ntype
in
hg
.
ntypes
:
for
ntype
in
hg
.
ntypes
:
if
len
(
hg
.
nodes
[
ntype
].
data
)
==
0
or
'
feat
'
not
in
hg
.
nodes
[
ntype
].
data
:
if
len
(
hg
.
nodes
[
ntype
].
data
)
==
0
or
"
feat
"
not
in
hg
.
nodes
[
ntype
].
data
:
node_feats
.
append
(
None
)
node_feats
.
append
(
None
)
else
:
else
:
feat
=
hg
.
nodes
[
ntype
].
data
.
pop
(
'
feat
'
)
feat
=
hg
.
nodes
[
ntype
].
data
.
pop
(
"
feat
"
)
node_feats
.
append
(
feat
.
share_memory_
())
node_feats
.
append
(
feat
.
share_memory_
())
# get target category id
# get target category id
...
@@ -241,25 +285,27 @@ def track_acc(data):
...
@@ -241,25 +285,27 @@ def track_acc(data):
if
ntype
==
category
:
if
ntype
==
category
:
category_id
=
i
category_id
=
i
g
=
dgl
.
to_homogeneous
(
hg
)
g
=
dgl
.
to_homogeneous
(
hg
)
u
,
v
,
eid
=
g
.
all_edges
(
form
=
'
all
'
)
u
,
v
,
eid
=
g
.
all_edges
(
form
=
"
all
"
)
# global norm
# global norm
_
,
inverse_index
,
count
=
th
.
unique
(
v
,
return_inverse
=
True
,
return_counts
=
True
)
_
,
inverse_index
,
count
=
th
.
unique
(
v
,
return_inverse
=
True
,
return_counts
=
True
)
degrees
=
count
[
inverse_index
]
degrees
=
count
[
inverse_index
]
norm
=
th
.
ones
(
eid
.
shape
[
0
])
/
degrees
norm
=
th
.
ones
(
eid
.
shape
[
0
])
/
degrees
norm
=
norm
.
unsqueeze
(
1
)
norm
=
norm
.
unsqueeze
(
1
)
g
.
edata
[
'
norm
'
]
=
norm
g
.
edata
[
"
norm
"
]
=
norm
g
.
edata
[
'
etype
'
]
=
g
.
edata
[
dgl
.
ETYPE
]
g
.
edata
[
"
etype
"
]
=
g
.
edata
[
dgl
.
ETYPE
]
g
.
ndata
[
'
type_id
'
]
=
g
.
ndata
[
dgl
.
NID
]
g
.
ndata
[
"
type_id
"
]
=
g
.
ndata
[
dgl
.
NID
]
g
.
ndata
[
'
ntype
'
]
=
g
.
ndata
[
dgl
.
NTYPE
]
g
.
ndata
[
"
ntype
"
]
=
g
.
ndata
[
dgl
.
NTYPE
]
node_ids
=
th
.
arange
(
g
.
number_of_nodes
())
node_ids
=
th
.
arange
(
g
.
number_of_nodes
())
# find out the target node ids
# find out the target node ids
node_tids
=
g
.
ndata
[
dgl
.
NTYPE
]
node_tids
=
g
.
ndata
[
dgl
.
NTYPE
]
loc
=
(
node_tids
==
category_id
)
loc
=
node_tids
==
category_id
target_nids
=
node_ids
[
loc
]
target_nids
=
node_ids
[
loc
]
g
=
g
.
formats
(
'
csc
'
)
g
=
g
.
formats
(
"
csc
"
)
sampler
=
dgl
.
dataloading
.
MultiLayerNeighborSampler
(
fanouts
)
sampler
=
dgl
.
dataloading
.
MultiLayerNeighborSampler
(
fanouts
)
train_loader
=
dgl
.
dataloading
.
DataLoader
(
train_loader
=
dgl
.
dataloading
.
DataLoader
(
g
,
g
,
...
@@ -268,7 +314,8 @@ def track_acc(data):
...
@@ -268,7 +314,8 @@ def track_acc(data):
batch_size
=
batch_size
,
batch_size
=
batch_size
,
shuffle
=
True
,
shuffle
=
True
,
drop_last
=
False
,
drop_last
=
False
,
num_workers
=
num_workers
)
num_workers
=
num_workers
,
)
test_loader
=
dgl
.
dataloading
.
DataLoader
(
test_loader
=
dgl
.
dataloading
.
DataLoader
(
g
,
g
,
target_nids
[
test_idx
],
target_nids
[
test_idx
],
...
@@ -276,21 +323,25 @@ def track_acc(data):
...
@@ -276,21 +323,25 @@ def track_acc(data):
batch_size
=
batch_size
,
batch_size
=
batch_size
,
shuffle
=
True
,
shuffle
=
True
,
drop_last
=
False
,
drop_last
=
False
,
num_workers
=
num_workers
)
num_workers
=
num_workers
,
)
# node features
# node features
# None for one-hot feature, if not none, it should be the feature tensor.
# None for one-hot feature, if not none, it should be the feature tensor.
embed_layer
=
RelGraphEmbedLayer
(
device
,
embed_layer
=
RelGraphEmbedLayer
(
device
,
g
.
number_of_nodes
(),
g
.
number_of_nodes
(),
node_tids
,
node_tids
,
num_of_ntype
,
num_of_ntype
,
node_feats
,
node_feats
,
n_hidden
,
n_hidden
,
sparse_emb
=
True
)
sparse_emb
=
True
,
)
# create model
# create model
# all model params are in device.
# all model params are in device.
model
=
EntityClassify
(
device
,
model
=
EntityClassify
(
device
,
g
.
number_of_nodes
(),
g
.
number_of_nodes
(),
n_hidden
,
n_hidden
,
num_classes
,
num_classes
,
...
@@ -299,14 +350,19 @@ def track_acc(data):
...
@@ -299,14 +350,19 @@ def track_acc(data):
num_hidden_layers
=
n_layers
-
2
,
num_hidden_layers
=
n_layers
-
2
,
dropout
=
dropout
,
dropout
=
dropout
,
use_self_loop
=
use_self_loop
,
use_self_loop
=
use_self_loop
,
layer_norm
=
False
)
layer_norm
=
False
,
)
embed_layer
=
embed_layer
.
to
(
device
)
embed_layer
=
embed_layer
.
to
(
device
)
model
=
model
.
to
(
device
)
model
=
model
.
to
(
device
)
all_params
=
itertools
.
chain
(
model
.
parameters
(),
embed_layer
.
embeds
.
parameters
())
all_params
=
itertools
.
chain
(
model
.
parameters
(),
embed_layer
.
embeds
.
parameters
()
)
optimizer
=
th
.
optim
.
Adam
(
all_params
,
lr
=
lr
,
weight_decay
=
l2norm
)
optimizer
=
th
.
optim
.
Adam
(
all_params
,
lr
=
lr
,
weight_decay
=
l2norm
)
emb_optimizer
=
th
.
optim
.
SparseAdam
(
list
(
embed_layer
.
node_embeds
.
parameters
()),
lr
=
lr
)
emb_optimizer
=
th
.
optim
.
SparseAdam
(
list
(
embed_layer
.
node_embeds
.
parameters
()),
lr
=
lr
)
print
(
"start training..."
)
print
(
"start training..."
)
for
epoch
in
range
(
n_epochs
):
for
epoch
in
range
(
n_epochs
):
...
@@ -315,12 +371,14 @@ def track_acc(data):
...
@@ -315,12 +371,14 @@ def track_acc(data):
for
i
,
sample_data
in
enumerate
(
train_loader
):
for
i
,
sample_data
in
enumerate
(
train_loader
):
input_nodes
,
output_nodes
,
blocks
=
sample_data
input_nodes
,
output_nodes
,
blocks
=
sample_data
feats
=
embed_layer
(
input_nodes
,
feats
=
embed_layer
(
blocks
[
0
].
srcdata
[
'ntype'
],
input_nodes
,
blocks
[
0
].
srcdata
[
'type_id'
],
blocks
[
0
].
srcdata
[
"ntype"
],
node_feats
)
blocks
[
0
].
srcdata
[
"type_id"
],
node_feats
,
)
logits
=
model
(
blocks
,
feats
)
logits
=
model
(
blocks
,
feats
)
seed_idx
=
blocks
[
-
1
].
dstdata
[
'
type_id
'
]
seed_idx
=
blocks
[
-
1
].
dstdata
[
"
type_id
"
]
loss
=
F
.
cross_entropy
(
logits
,
labels
[
seed_idx
])
loss
=
F
.
cross_entropy
(
logits
,
labels
[
seed_idx
])
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
emb_optimizer
.
zero_grad
()
emb_optimizer
.
zero_grad
()
...
@@ -329,10 +387,14 @@ def track_acc(data):
...
@@ -329,10 +387,14 @@ def track_acc(data):
optimizer
.
step
()
optimizer
.
step
()
emb_optimizer
.
step
()
emb_optimizer
.
step
()
print
(
'
start testing...
'
)
print
(
"
start testing...
"
)
test_logits
,
test_seeds
=
evaluate
(
model
,
embed_layer
,
test_loader
,
node_feats
)
test_logits
,
test_seeds
=
evaluate
(
model
,
embed_layer
,
test_loader
,
node_feats
)
test_loss
=
F
.
cross_entropy
(
test_logits
,
labels
[
test_seeds
].
cpu
()).
item
()
test_loss
=
F
.
cross_entropy
(
test_logits
,
labels
[
test_seeds
].
cpu
()).
item
()
test_acc
=
th
.
sum
(
test_logits
.
argmax
(
dim
=
1
)
==
labels
[
test_seeds
].
cpu
()).
item
()
/
len
(
test_seeds
)
test_acc
=
th
.
sum
(
test_logits
.
argmax
(
dim
=
1
)
==
labels
[
test_seeds
].
cpu
()
).
item
()
/
len
(
test_seeds
)
return
test_acc
return
test_acc
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment