Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9836f78e
Unverified
Commit
9836f78e
authored
Feb 19, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 19, 2023
Browse files
autoformat (#5322)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
704bcaf6
Changes
57
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
201 additions
and
142 deletions
+201
-142
benchmarks/benchmarks/api/bench_nn_graphconv.py
benchmarks/benchmarks/api/bench_nn_graphconv.py
+3
-3
benchmarks/benchmarks/api/bench_nn_heterographconv.py
benchmarks/benchmarks/api/bench_nn_heterographconv.py
+3
-3
benchmarks/benchmarks/api/bench_node_subgraph.py
benchmarks/benchmarks/api/bench_node_subgraph.py
+3
-3
benchmarks/benchmarks/api/bench_random_walk.py
benchmarks/benchmarks/api/bench_random_walk.py
+2
-2
benchmarks/benchmarks/api/bench_readout.py
benchmarks/benchmarks/api/bench_readout.py
+2
-2
benchmarks/benchmarks/api/bench_reverse.py
benchmarks/benchmarks/api/bench_reverse.py
+2
-2
benchmarks/benchmarks/api/bench_sample_neighbors.py
benchmarks/benchmarks/api/bench_sample_neighbors.py
+3
-3
benchmarks/benchmarks/api/bench_to_block.py
benchmarks/benchmarks/api/bench_to_block.py
+2
-2
benchmarks/benchmarks/api/bench_udf_apply_edges.py
benchmarks/benchmarks/api/bench_udf_apply_edges.py
+3
-3
benchmarks/benchmarks/api/bench_udf_multi_update_all.py
benchmarks/benchmarks/api/bench_udf_multi_update_all.py
+3
-3
benchmarks/benchmarks/api/bench_udf_update_all.py
benchmarks/benchmarks/api/bench_udf_update_all.py
+3
-3
benchmarks/benchmarks/api/bench_unbatch.py
benchmarks/benchmarks/api/bench_unbatch.py
+2
-2
benchmarks/benchmarks/kernel/bench_edgesoftmax.py
benchmarks/benchmarks/kernel/bench_edgesoftmax.py
+2
-2
benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py
benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py
+2
-2
benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py
benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py
+2
-2
benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py
benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py
+2
-2
benchmarks/benchmarks/model_acc/bench_gat.py
benchmarks/benchmarks/model_acc/bench_gat.py
+1
-2
benchmarks/benchmarks/model_acc/bench_gcn.py
benchmarks/benchmarks/model_acc/bench_gcn.py
+1
-2
benchmarks/benchmarks/model_acc/bench_gcn_udf.py
benchmarks/benchmarks/model_acc/bench_gcn_udf.py
+1
-2
benchmarks/benchmarks/model_acc/bench_rgcn_ns.py
benchmarks/benchmarks/model_acc/bench_rgcn_ns.py
+159
-97
No files found.
benchmarks/benchmarks/api/bench_nn_graphconv.py
View file @
9836f78e
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.nn.pytorch
import
SAGEConv
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_nn_heterographconv.py
View file @
9836f78e
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.nn.pytorch
import
HeteroGraphConv
,
SAGEConv
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_node_subgraph.py
View file @
9836f78e
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_random_walk.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_readout.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_reverse.py
View file @
9836f78e
import
time
import
dgl
import
numpy
as
np
import
torch
import
dgl
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_sample_neighbors.py
View file @
9836f78e
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_to_block.py
View file @
9836f78e
import
time
import
dgl
import
numpy
as
np
import
torch
import
dgl
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_udf_apply_edges.py
View file @
9836f78e
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_udf_multi_update_all.py
View file @
9836f78e
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_udf_update_all.py
View file @
9836f78e
import
time
import
numpy
as
np
import
torch
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/api/bench_unbatch.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_edgesoftmax.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py
View file @
9836f78e
import
time
import
torch
import
dgl
import
torch
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_gat.py
View file @
9836f78e
import
dgl
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
from
dgl.nn.pytorch
import
GATConv
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_gcn.py
View file @
9836f78e
import
dgl
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
from
dgl.nn.pytorch
import
GraphConv
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_gcn_udf.py
View file @
9836f78e
import
dgl
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
from
..
import
utils
...
...
benchmarks/benchmarks/model_acc/bench_rgcn_ns.py
View file @
9836f78e
import
dgl
import
itertools
import
time
import
dgl
import
dgl.nn.pytorch
as
dglnn
import
torch
as
th
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.multiprocessing
as
mp
from
torch.utils.data
import
DataLoader
import
dgl.nn.pytorch
as
dglnn
from
dgl.nn
import
RelGraphConv
import
time
from
torch.utils.data
import
DataLoader
from
..
import
utils
class
EntityClassify
(
nn
.
Module
):
"""
Entity classification class for RGCN
"""Entity classification class for RGCN
Parameters
----------
device : int
...
...
@@ -35,7 +37,9 @@ class EntityClassify(nn.Module):
use_self_loop : bool
Use self loop if True, default False.
"""
def
__init__
(
self
,
def
__init__
(
self
,
device
,
num_nodes
,
h_dim
,
...
...
@@ -45,7 +49,8 @@ class EntityClassify(nn.Module):
num_hidden_layers
=
1
,
dropout
=
0
,
use_self_loop
=
False
,
layer_norm
=
False
):
layer_norm
=
False
,
):
super
(
EntityClassify
,
self
).
__init__
()
self
.
device
=
device
self
.
num_nodes
=
num_nodes
...
...
@@ -60,22 +65,47 @@ class EntityClassify(nn.Module):
self
.
layers
=
nn
.
ModuleList
()
# i2h
self
.
layers
.
append
(
RelGraphConv
(
self
.
h_dim
,
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
))
self
.
layers
.
append
(
RelGraphConv
(
self
.
h_dim
,
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
,
)
)
# h2h
for
idx
in
range
(
self
.
num_hidden_layers
):
self
.
layers
.
append
(
RelGraphConv
(
self
.
h_dim
,
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
))
self
.
layers
.
append
(
RelGraphConv
(
self
.
h_dim
,
self
.
h_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
F
.
relu
,
self_loop
=
self
.
use_self_loop
,
dropout
=
self
.
dropout
,
layer_norm
=
layer_norm
,
)
)
# h2o
self
.
layers
.
append
(
RelGraphConv
(
self
.
h_dim
,
self
.
out_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
None
,
self
.
layers
.
append
(
RelGraphConv
(
self
.
h_dim
,
self
.
out_dim
,
self
.
num_rels
,
"basis"
,
self
.
num_bases
,
activation
=
None
,
self_loop
=
self
.
use_self_loop
,
layer_norm
=
layer_norm
))
layer_norm
=
layer_norm
,
)
)
def
forward
(
self
,
blocks
,
feats
,
norm
=
None
):
if
blocks
is
None
:
...
...
@@ -84,9 +114,10 @@ class EntityClassify(nn.Module):
h
=
feats
for
layer
,
block
in
zip
(
self
.
layers
,
blocks
):
block
=
block
.
to
(
self
.
device
)
h
=
layer
(
block
,
h
,
block
.
edata
[
'
etype
'
],
block
.
edata
[
'
norm
'
])
h
=
layer
(
block
,
h
,
block
.
edata
[
"
etype
"
],
block
.
edata
[
"
norm
"
])
return
h
class
RelGraphEmbedLayer
(
nn
.
Module
):
r
"""Embedding layer for featureless heterograph.
Parameters
...
...
@@ -107,7 +138,9 @@ class RelGraphEmbedLayer(nn.Module):
embed_name : str, optional
Embed name
"""
def
__init__
(
self
,
def
__init__
(
self
,
device
,
num_nodes
,
node_tids
,
...
...
@@ -115,7 +148,8 @@ class RelGraphEmbedLayer(nn.Module):
input_size
,
embed_size
,
sparse_emb
=
False
,
embed_name
=
'embed'
):
embed_name
=
"embed"
,
):
super
(
RelGraphEmbedLayer
,
self
).
__init__
()
self
.
device
=
device
self
.
embed_size
=
embed_size
...
...
@@ -135,7 +169,9 @@ class RelGraphEmbedLayer(nn.Module):
nn
.
init
.
xavier_uniform_
(
embed
)
self
.
embeds
[
str
(
ntype
)]
=
embed
self
.
node_embeds
=
th
.
nn
.
Embedding
(
node_tids
.
shape
[
0
],
self
.
embed_size
,
sparse
=
self
.
sparse_emb
)
self
.
node_embeds
=
th
.
nn
.
Embedding
(
node_tids
.
shape
[
0
],
self
.
embed_size
,
sparse
=
self
.
sparse_emb
)
nn
.
init
.
uniform_
(
self
.
node_embeds
.
weight
,
-
1.0
,
1.0
)
def
forward
(
self
,
node_ids
,
node_tids
,
type_ids
,
features
):
...
...
@@ -157,17 +193,22 @@ class RelGraphEmbedLayer(nn.Module):
embeddings as the input of the next layer
"""
tsd_ids
=
node_ids
.
to
(
self
.
node_embeds
.
weight
.
device
)
embeds
=
th
.
empty
(
node_ids
.
shape
[
0
],
self
.
embed_size
,
device
=
self
.
device
)
embeds
=
th
.
empty
(
node_ids
.
shape
[
0
],
self
.
embed_size
,
device
=
self
.
device
)
for
ntype
in
range
(
self
.
num_of_ntype
):
if
features
[
ntype
]
is
not
None
:
loc
=
node_tids
==
ntype
embeds
[
loc
]
=
features
[
ntype
][
type_ids
[
loc
]].
to
(
self
.
device
)
@
self
.
embeds
[
str
(
ntype
)].
to
(
self
.
device
)
embeds
[
loc
]
=
features
[
ntype
][
type_ids
[
loc
]].
to
(
self
.
device
)
@
self
.
embeds
[
str
(
ntype
)].
to
(
self
.
device
)
else
:
loc
=
node_tids
==
ntype
embeds
[
loc
]
=
self
.
node_embeds
(
tsd_ids
[
loc
]).
to
(
self
.
device
)
return
embeds
def
evaluate
(
model
,
embed_layer
,
eval_loader
,
node_feats
):
model
.
eval
()
embed_layer
.
eval
()
...
...
@@ -178,36 +219,39 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
for
sample_data
in
eval_loader
:
th
.
cuda
.
empty_cache
()
_
,
_
,
blocks
=
sample_data
feats
=
embed_layer
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
feats
=
embed_layer
(
blocks
[
0
].
srcdata
[
dgl
.
NID
],
blocks
[
0
].
srcdata
[
dgl
.
NTYPE
],
blocks
[
0
].
srcdata
[
'type_id'
],
node_feats
)
blocks
[
0
].
srcdata
[
"type_id"
],
node_feats
,
)
logits
=
model
(
blocks
,
feats
)
eval_logits
.
append
(
logits
.
cpu
().
detach
())
eval_seeds
.
append
(
blocks
[
-
1
].
dstdata
[
'
type_id
'
].
cpu
().
detach
())
eval_seeds
.
append
(
blocks
[
-
1
].
dstdata
[
"
type_id
"
].
cpu
().
detach
())
eval_logits
=
th
.
cat
(
eval_logits
)
eval_seeds
=
th
.
cat
(
eval_seeds
)
return
eval_logits
,
eval_seeds
@
utils
.
benchmark
(
'acc'
,
timeout
=
3600
)
# ogbn-mag takes ~1 hour to train
@
utils
.
parametrize
(
'data'
,
[
'am'
,
'ogbn-mag'
])
@
utils
.
benchmark
(
"acc"
,
timeout
=
3600
)
# ogbn-mag takes ~1 hour to train
@
utils
.
parametrize
(
"data"
,
[
"am"
,
"ogbn-mag"
])
def
track_acc
(
data
):
dataset
=
utils
.
process_data
(
data
)
device
=
utils
.
get_bench_device
()
if
data
==
'
am
'
:
if
data
==
"
am
"
:
n_bases
=
40
l2norm
=
5e-4
n_epochs
=
20
elif
data
==
'
ogbn-mag
'
:
elif
data
==
"
ogbn-mag
"
:
n_bases
=
2
l2norm
=
0
n_epochs
=
20
else
:
raise
ValueError
()
fanouts
=
[
25
,
15
]
fanouts
=
[
25
,
15
]
n_layers
=
2
batch_size
=
1024
n_hidden
=
64
...
...
@@ -219,20 +263,20 @@ def track_acc(data):
hg
=
dataset
[
0
]
category
=
dataset
.
predict_category
num_classes
=
dataset
.
num_classes
train_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
'
train_mask
'
)
train_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
"
train_mask
"
)
train_idx
=
th
.
nonzero
(
train_mask
,
as_tuple
=
False
).
squeeze
()
test_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
'
test_mask
'
)
test_mask
=
hg
.
nodes
[
category
].
data
.
pop
(
"
test_mask
"
)
test_idx
=
th
.
nonzero
(
test_mask
,
as_tuple
=
False
).
squeeze
()
labels
=
hg
.
nodes
[
category
].
data
.
pop
(
'
labels
'
).
to
(
device
)
labels
=
hg
.
nodes
[
category
].
data
.
pop
(
"
labels
"
).
to
(
device
)
num_of_ntype
=
len
(
hg
.
ntypes
)
num_rels
=
len
(
hg
.
canonical_etypes
)
node_feats
=
[]
for
ntype
in
hg
.
ntypes
:
if
len
(
hg
.
nodes
[
ntype
].
data
)
==
0
or
'
feat
'
not
in
hg
.
nodes
[
ntype
].
data
:
if
len
(
hg
.
nodes
[
ntype
].
data
)
==
0
or
"
feat
"
not
in
hg
.
nodes
[
ntype
].
data
:
node_feats
.
append
(
None
)
else
:
feat
=
hg
.
nodes
[
ntype
].
data
.
pop
(
'
feat
'
)
feat
=
hg
.
nodes
[
ntype
].
data
.
pop
(
"
feat
"
)
node_feats
.
append
(
feat
.
share_memory_
())
# get target category id
...
...
@@ -241,25 +285,27 @@ def track_acc(data):
if
ntype
==
category
:
category_id
=
i
g
=
dgl
.
to_homogeneous
(
hg
)
u
,
v
,
eid
=
g
.
all_edges
(
form
=
'
all
'
)
u
,
v
,
eid
=
g
.
all_edges
(
form
=
"
all
"
)
# global norm
_
,
inverse_index
,
count
=
th
.
unique
(
v
,
return_inverse
=
True
,
return_counts
=
True
)
_
,
inverse_index
,
count
=
th
.
unique
(
v
,
return_inverse
=
True
,
return_counts
=
True
)
degrees
=
count
[
inverse_index
]
norm
=
th
.
ones
(
eid
.
shape
[
0
])
/
degrees
norm
=
norm
.
unsqueeze
(
1
)
g
.
edata
[
'
norm
'
]
=
norm
g
.
edata
[
'
etype
'
]
=
g
.
edata
[
dgl
.
ETYPE
]
g
.
ndata
[
'
type_id
'
]
=
g
.
ndata
[
dgl
.
NID
]
g
.
ndata
[
'
ntype
'
]
=
g
.
ndata
[
dgl
.
NTYPE
]
g
.
edata
[
"
norm
"
]
=
norm
g
.
edata
[
"
etype
"
]
=
g
.
edata
[
dgl
.
ETYPE
]
g
.
ndata
[
"
type_id
"
]
=
g
.
ndata
[
dgl
.
NID
]
g
.
ndata
[
"
ntype
"
]
=
g
.
ndata
[
dgl
.
NTYPE
]
node_ids
=
th
.
arange
(
g
.
number_of_nodes
())
# find out the target node ids
node_tids
=
g
.
ndata
[
dgl
.
NTYPE
]
loc
=
(
node_tids
==
category_id
)
loc
=
node_tids
==
category_id
target_nids
=
node_ids
[
loc
]
g
=
g
.
formats
(
'
csc
'
)
g
=
g
.
formats
(
"
csc
"
)
sampler
=
dgl
.
dataloading
.
MultiLayerNeighborSampler
(
fanouts
)
train_loader
=
dgl
.
dataloading
.
DataLoader
(
g
,
...
...
@@ -268,7 +314,8 @@ def track_acc(data):
batch_size
=
batch_size
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
num_workers
)
num_workers
=
num_workers
,
)
test_loader
=
dgl
.
dataloading
.
DataLoader
(
g
,
target_nids
[
test_idx
],
...
...
@@ -276,21 +323,25 @@ def track_acc(data):
batch_size
=
batch_size
,
shuffle
=
True
,
drop_last
=
False
,
num_workers
=
num_workers
)
num_workers
=
num_workers
,
)
# node features
# None for one-hot feature, if not none, it should be the feature tensor.
embed_layer
=
RelGraphEmbedLayer
(
device
,
embed_layer
=
RelGraphEmbedLayer
(
device
,
g
.
number_of_nodes
(),
node_tids
,
num_of_ntype
,
node_feats
,
n_hidden
,
sparse_emb
=
True
)
sparse_emb
=
True
,
)
# create model
# all model params are in device.
model
=
EntityClassify
(
device
,
model
=
EntityClassify
(
device
,
g
.
number_of_nodes
(),
n_hidden
,
num_classes
,
...
...
@@ -299,14 +350,19 @@ def track_acc(data):
num_hidden_layers
=
n_layers
-
2
,
dropout
=
dropout
,
use_self_loop
=
use_self_loop
,
layer_norm
=
False
)
layer_norm
=
False
,
)
embed_layer
=
embed_layer
.
to
(
device
)
model
=
model
.
to
(
device
)
all_params
=
itertools
.
chain
(
model
.
parameters
(),
embed_layer
.
embeds
.
parameters
())
all_params
=
itertools
.
chain
(
model
.
parameters
(),
embed_layer
.
embeds
.
parameters
()
)
optimizer
=
th
.
optim
.
Adam
(
all_params
,
lr
=
lr
,
weight_decay
=
l2norm
)
emb_optimizer
=
th
.
optim
.
SparseAdam
(
list
(
embed_layer
.
node_embeds
.
parameters
()),
lr
=
lr
)
emb_optimizer
=
th
.
optim
.
SparseAdam
(
list
(
embed_layer
.
node_embeds
.
parameters
()),
lr
=
lr
)
print
(
"start training..."
)
for
epoch
in
range
(
n_epochs
):
...
...
@@ -315,12 +371,14 @@ def track_acc(data):
for
i
,
sample_data
in
enumerate
(
train_loader
):
input_nodes
,
output_nodes
,
blocks
=
sample_data
feats
=
embed_layer
(
input_nodes
,
blocks
[
0
].
srcdata
[
'ntype'
],
blocks
[
0
].
srcdata
[
'type_id'
],
node_feats
)
feats
=
embed_layer
(
input_nodes
,
blocks
[
0
].
srcdata
[
"ntype"
],
blocks
[
0
].
srcdata
[
"type_id"
],
node_feats
,
)
logits
=
model
(
blocks
,
feats
)
seed_idx
=
blocks
[
-
1
].
dstdata
[
'
type_id
'
]
seed_idx
=
blocks
[
-
1
].
dstdata
[
"
type_id
"
]
loss
=
F
.
cross_entropy
(
logits
,
labels
[
seed_idx
])
optimizer
.
zero_grad
()
emb_optimizer
.
zero_grad
()
...
...
@@ -329,10 +387,14 @@ def track_acc(data):
optimizer
.
step
()
emb_optimizer
.
step
()
print
(
'
start testing...
'
)
print
(
"
start testing...
"
)
test_logits
,
test_seeds
=
evaluate
(
model
,
embed_layer
,
test_loader
,
node_feats
)
test_logits
,
test_seeds
=
evaluate
(
model
,
embed_layer
,
test_loader
,
node_feats
)
test_loss
=
F
.
cross_entropy
(
test_logits
,
labels
[
test_seeds
].
cpu
()).
item
()
test_acc
=
th
.
sum
(
test_logits
.
argmax
(
dim
=
1
)
==
labels
[
test_seeds
].
cpu
()).
item
()
/
len
(
test_seeds
)
test_acc
=
th
.
sum
(
test_logits
.
argmax
(
dim
=
1
)
==
labels
[
test_seeds
].
cpu
()
).
item
()
/
len
(
test_seeds
)
return
test_acc
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment