OpenDAS / dgl — Commits

Commit 929c99ed (unverified)
Authored Aug 05, 2020 by Quan (Andy) Gan; committed by GitHub, Aug 05, 2020

[Model] Rewrite GraphSAGE example (#1938)

Parent: 879e4ae5
Showing 1 changed file with 21 additions and 21 deletions.

examples/pytorch/graphsage/train_full.py (+21, -21)
@@ -18,7 +18,6 @@ from dgl.nn.pytorch.conv import SAGEConv

 class GraphSAGE(nn.Module):
     def __init__(self,
-                 g,
                  in_feats,
                  n_hidden,
                  n_classes,
@@ -28,27 +27,31 @@ class GraphSAGE(nn.Module):
                  aggregator_type):
         super(GraphSAGE, self).__init__()
         self.layers = nn.ModuleList()
-        self.g = g
+        self.dropout = nn.Dropout(dropout)
+        self.activation = activation

         # input layer
-        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation))
+        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))
         # hidden layers
         for i in range(n_layers - 1):
-            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation))
+            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
         # output layer
-        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None
+        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation None

-    def forward(self, features):
-        h = features
-        for layer in self.layers:
-            h = layer(self.g, h)
+    def forward(self, graph, inputs):
+        h = self.dropout(inputs)
+        for l, layer in enumerate(self.layers):
+            h = layer(graph, h)
+            if l != len(self.layers) - 1:
+                h = self.activation(h)
+                h = self.dropout(h)
         return h


-def evaluate(model, features, labels, mask):
+def evaluate(model, graph, features, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features)
+        logits = model(graph, features)
         logits = logits[mask]
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
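The hunk above cuts off inside evaluate, right after the argmax over the logits; the rest of the function is elided by the diff viewer. Purely for orientation, a hypothetical helper showing how such argmax indices are typically turned into an accuracy score; the name accuracy and the exact arithmetic are illustrative, not the file's verbatim continuation:

import torch

def accuracy(logits, labels):
    # Hypothetical helper, not verbatim from train_full.py: mirrors the
    # usual pattern following torch.max(logits, dim=1) in evaluate().
    _, indices = torch.max(logits, dim=1)      # predicted class per node
    correct = torch.sum(indices == labels)     # number of correct predictions
    return correct.item() * 1.0 / len(labels)  # fraction correct on the mask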
@@ -101,19 +104,16 @@ def main(args):
     n_edges = g.number_of_edges()

     # create GraphSAGE model
-    model = GraphSAGE(g,
-                      in_feats,
+    model = GraphSAGE(in_feats,
                       args.n_hidden,
                       n_classes,
                       args.n_layers,
                       F.relu,
                       args.dropout,
                       args.aggregator_type)

     if cuda:
         model.cuda()
-    loss_fcn = torch.nn.CrossEntropyLoss()

     # use optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
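This hunk relies on several args.* fields whose definitions fall outside the diff. A sketch of the argument parser the script plausibly builds, reconstructed only from the args.* names visible above; the default values and help strings are assumptions, not taken from the file:

import argparse

# Hypothetical reconstruction from the args.* names visible in the diff;
# the defaults below are assumptions, not the script's actual defaults.
parser = argparse.ArgumentParser(description="GraphSAGE")
parser.add_argument("--dropout", type=float, default=0.5,
                    help="dropout probability")
parser.add_argument("--lr", type=float, default=1e-2,
                    help="learning rate")
parser.add_argument("--n-hidden", type=int, default=16,
                    help="number of hidden units")
parser.add_argument("--n-layers", type=int, default=1,
                    help="number of hidden layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
                    help="weight for L2 regularization")
parser.add_argument("--aggregator-type", type=str, default="mean",
                    help="aggregator type passed through to SAGEConv")
args = parser.parse_args()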
@@ -125,8 +125,8 @@ def main(args):
         if epoch >= 3:
             t0 = time.time()
         # forward
-        logits = model(features)
-        loss = loss_fcn(logits[train_mask], labels[train_mask])
+        logits = model(g, features)
+        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

         optimizer.zero_grad()
         loss.backward()
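Alongside passing the graph explicitly, this hunk swaps the nn.CrossEntropyLoss module (whose loss_fcn binding was removed in the previous hunk) for the functional F.cross_entropy. With default reduction the two compute the same value, as a quick standalone check illustrates:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)             # 4 examples, 3 classes
labels = torch.tensor([0, 2, 1, 2])

loss_module = torch.nn.CrossEntropyLoss()(logits, labels)
loss_functional = F.cross_entropy(logits, labels)
assert torch.allclose(loss_module, loss_functional)  # identical result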
@@ -135,13 +135,13 @@ def main(args):
         if epoch >= 3:
             dur.append(time.time() - t0)

-        acc = evaluate(model, features, labels, val_mask)
+        acc = evaluate(model, g, features, labels, val_mask)
         print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
               "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                             acc, n_edges / np.mean(dur) / 1000))

     print()
-    acc = evaluate(model, features, labels, test_mask)
+    acc = evaluate(model, g, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))
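Taken together, the rewrite makes the module stateless with respect to the graph: dropout and activation move out of SAGEConv into the module, and the graph is supplied at call time instead of being stored in the constructor. A minimal sketch assembling the "+" side of the diff into runnable form; the constructor parameters n_layers, activation, and dropout are inferred from the positional call in main() and the constructor body, and the toy graph is made up for illustration:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import SAGEConv

class GraphSAGE(nn.Module):
    # The new-side code from the diff above, assembled into one class.
    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout, aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
        # output layer (no activation)
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type))

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

# Toy 4-node cycle graph (made up for illustration); self-loops ensure
# every node has at least one in-edge to aggregate from. Assumes the
# dgl.graph constructor available in DGL 0.5 and later.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
feats = torch.randn(4, 8)               # 4 nodes, 8 input features

model = GraphSAGE(in_feats=8, n_hidden=16, n_classes=3, n_layers=1,
                  activation=F.relu, dropout=0.5, aggregator_type="mean")
logits = model(g, feats)                 # graph passed at call time now
print(logits.shape)                      # torch.Size([4, 3])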