Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9e46423e
Unverified
Commit
9e46423e
authored
Jun 02, 2022
by
Chang Liu
Committed by
GitHub
Jun 02, 2022
Browse files
[Bugfix] Fix cluster-gat examples (#4068)
Co-authored-by:
Mufei Li
<
mufeili1996@gmail.com
>
parent
d9c25521
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
examples/pytorch/ogb/cluster-gat/main.py
examples/pytorch/ogb/cluster-gat/main.py
+7
-6
No files found.
examples/pytorch/ogb/cluster-gat/main.py
View file @
9e46423e
...
@@ -51,7 +51,7 @@ class GAT(nn.Module):
...
@@ -51,7 +51,7 @@ class GAT(nn.Module):
attn_drop
=
dropout
,
attn_drop
=
dropout
,
activation
=
None
,
activation
=
None
,
negative_slope
=
0.2
))
negative_slope
=
0.2
))
def
forward
(
self
,
g
,
x
):
def
forward
(
self
,
g
,
x
):
h
=
x
h
=
x
for
l
,
conv
in
enumerate
(
self
.
layers
):
for
l
,
conv
in
enumerate
(
self
.
layers
):
...
@@ -119,7 +119,8 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
...
@@ -119,7 +119,8 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
with
th
.
no_grad
():
with
th
.
no_grad
():
pred
=
model
.
inference
(
g
,
nfeat
,
batch_size
,
device
)
pred
=
model
.
inference
(
g
,
nfeat
,
batch_size
,
device
)
model
.
train
()
model
.
train
()
return
compute_acc
(
pred
[
val_nid
],
labels
[
val_nid
]),
compute_acc
(
pred
[
test_nid
],
labels
[
test_nid
]),
pred
labels_cpu
=
labels
.
to
(
th
.
device
(
'cpu'
))
return
compute_acc
(
pred
[
val_nid
],
labels_cpu
[
val_nid
]),
compute_acc
(
pred
[
test_nid
],
labels_cpu
[
test_nid
]),
pred
def
model_param_summary
(
model
):
def
model_param_summary
(
model
):
""" Count the model parameters """
""" Count the model parameters """
...
@@ -127,11 +128,10 @@ def model_param_summary(model):
...
@@ -127,11 +128,10 @@ def model_param_summary(model):
print
(
"Total Params {}"
.
format
(
cnt
))
print
(
"Total Params {}"
.
format
(
cnt
))
#### Entry point
#### Entry point
def
run
(
args
,
device
,
data
):
def
run
(
args
,
device
,
data
,
nfeat
):
# Unpack data
# Unpack data
train_nid
,
val_nid
,
test_nid
,
in_feats
,
labels
,
n_classes
,
g
,
cluster_iterator
=
data
train_nid
,
val_nid
,
test_nid
,
in_feats
,
labels
,
n_classes
,
g
,
cluster_iterator
=
data
labels
=
labels
.
to
(
device
)
labels
=
labels
.
to
(
device
)
nfeat
=
g
.
ndata
.
pop
(
'feat'
).
to
(
device
)
# Define model and optimizer
# Define model and optimizer
model
=
GAT
(
in_feats
,
args
.
num_heads
,
args
.
num_hidden
,
n_classes
,
args
.
num_layers
,
F
.
relu
,
args
.
dropout
)
model
=
GAT
(
in_feats
,
args
.
num_heads
,
args
.
num_hidden
,
n_classes
,
args
.
num_layers
,
F
.
relu
,
args
.
dropout
)
...
@@ -200,7 +200,7 @@ def run(args, device, data):
...
@@ -200,7 +200,7 @@ def run(args, device, data):
best_test_acc
=
test_acc
best_test_acc
=
test_acc
print
(
'Best Eval Acc {:.4f} Test Acc {:.4f}'
.
format
(
best_eval_acc
,
best_test_acc
))
print
(
'Best Eval Acc {:.4f} Test Acc {:.4f}'
.
format
(
best_eval_acc
,
best_test_acc
))
print
(
'Avg epoch time: {}'
.
format
(
avg
/
(
epoch
-
4
)))
print
(
'Avg epoch time: {}'
.
format
(
avg
/
(
epoch
-
4
)))
return
best_test_acc
return
best_test_acc
.
to
(
th
.
device
(
'cpu'
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
argparser
=
argparse
.
ArgumentParser
(
"multi-gpu training"
)
argparser
=
argparse
.
ArgumentParser
(
"multi-gpu training"
)
...
@@ -265,6 +265,7 @@ if __name__ == '__main__':
...
@@ -265,6 +265,7 @@ if __name__ == '__main__':
# Run 10 times
# Run 10 times
test_accs
=
[]
test_accs
=
[]
nfeat
=
graph
.
ndata
.
pop
(
'feat'
).
to
(
device
)
for
i
in
range
(
10
):
for
i
in
range
(
10
):
test_accs
.
append
(
run
(
args
,
device
,
data
))
test_accs
.
append
(
run
(
args
,
device
,
data
,
nfeat
))
print
(
'Average test accuracy:'
,
np
.
mean
(
test_accs
),
'±'
,
np
.
std
(
test_accs
))
print
(
'Average test accuracy:'
,
np
.
mean
(
test_accs
),
'±'
,
np
.
std
(
test_accs
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment