Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4ec8f204
Unverified
Commit
4ec8f204
authored
Mar 05, 2020
by
Mufei Li
Committed by
GitHub
Mar 05, 2020
Browse files
Update (#1319)
parent
066d290f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
7 additions
and
12 deletions
+7
-12
examples/pytorch/gin/README.md
examples/pytorch/gin/README.md
+1
-1
examples/pytorch/gin/dataloader.py
examples/pytorch/gin/dataloader.py
+1
-1
examples/pytorch/gin/main.py
examples/pytorch/gin/main.py
+4
-4
examples/pytorch/gin/parser.py
examples/pytorch/gin/parser.py
+1
-6
No files found.
examples/pytorch/gin/README.md
View file @
4ec8f204
...
...
@@ -6,7 +6,7 @@ Graph Isomorphism Network (GIN)
Dependencies
------------
-
PyTorch 1.
0.1
+
-
PyTorch 1.
1.0
+
-
sklearn
-
tqdm
...
...
examples/pytorch/gin/dataloader.py
View file @
4ec8f204
...
...
@@ -19,7 +19,7 @@ def collate(samples):
for
g
in
graphs
:
# deal with node feats
for
key
in
g
.
node_attr_schemes
().
keys
():
g
.
ndata
[
key
]
=
torch
.
from_numpy
(
g
.
ndata
[
key
]
)
.
float
()
g
.
ndata
[
key
]
=
g
.
ndata
[
key
].
float
()
# no edge feats
batched_graph
=
dgl
.
batch
(
graphs
)
labels
=
torch
.
tensor
(
labels
)
...
...
examples/pytorch/gin/main.py
View file @
4ec8f204
...
...
@@ -73,14 +73,14 @@ def eval_net(args, net, dataloader, criterion):
def
main
(
args
):
# set up seeds, args.seed supported
torch
.
manual_seed
(
seed
=
0
)
np
.
random
.
seed
(
seed
=
0
)
torch
.
manual_seed
(
seed
=
args
.
seed
)
np
.
random
.
seed
(
seed
=
args
.
seed
)
is_cuda
=
not
args
.
disable_cuda
and
torch
.
cuda
.
is_available
()
if
is_cuda
:
args
.
device
=
torch
.
device
(
"cuda:"
+
str
(
args
.
device
))
torch
.
cuda
.
manual_seed_all
(
seed
=
0
)
torch
.
cuda
.
manual_seed_all
(
seed
=
args
.
seed
)
else
:
args
.
device
=
torch
.
device
(
"cpu"
)
...
...
@@ -109,9 +109,9 @@ def main(args):
lrbar
=
tqdm
(
range
(
args
.
epochs
),
unit
=
"epoch"
,
position
=
5
,
ncols
=
0
,
file
=
sys
.
stdout
)
for
epoch
,
_
,
_
in
zip
(
tbar
,
vbar
,
lrbar
):
scheduler
.
step
()
train
(
args
,
model
,
trainloader
,
optimizer
,
criterion
,
epoch
)
scheduler
.
step
()
train_loss
,
train_acc
=
eval_net
(
args
,
model
,
trainloader
,
criterion
)
...
...
examples/pytorch/gin/parser.py
View file @
4ec8f204
...
...
@@ -19,6 +19,7 @@ class Parser():
# dataset
self
.
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
"MUTAG"
,
choices
=
[
'MUTAG'
,
'COLLAB'
,
'IMDBBINARY'
,
'IMDBMULTI'
],
help
=
'name of dataset (default: MUTAG)'
)
self
.
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
,
...
...
@@ -39,9 +40,6 @@ class Parser():
help
=
'which gpu device to use (default: 0)'
)
# net
self
.
parser
.
add_argument
(
'--net'
,
type
=
str
,
default
=
"gin"
,
help
=
'gnn net (default: gin)'
)
self
.
parser
.
add_argument
(
'--num_layers'
,
type
=
int
,
default
=
5
,
help
=
'number of layers (default: 5)'
)
...
...
@@ -64,9 +62,6 @@ class Parser():
self
.
parser
.
add_argument
(
'--learn_eps'
,
action
=
"store_true"
,
help
=
'learn the epsilon weighting'
)
self
.
parser
.
add_argument
(
'--degree_as_tag'
,
action
=
"store_true"
,
help
=
'take the degree of nodes as input feature'
)
# learning
self
.
parser
.
add_argument
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment