Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0d650443
"docs/source/api/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d1d580ecc1d1dfe28f14526f268bb3cbcbb764c9"
Commit
0d650443
authored
Jul 10, 2018
by
Ivan Brugere
Browse files
cleaning
cleaning and documentation
parent
ef18dab2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
15 deletions
+26
-15
examples/pytorch/geniepath.py
examples/pytorch/geniepath.py
+26
-15
No files found.
examples/pytorch/g
i
nipath.py
→
examples/pytorch/g
e
ni
e
path.py
View file @
0d650443
...
...
@@ -4,22 +4,22 @@
Created on Mon Jul 9 13:34:38 2018
@author: ivabruge
"""
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
GeniePath: Graph Neural Networks with Adaptive Receptive Paths
Paper: https://arxiv.org/abs/1802.00910
this model uses an LSTM on the node reductions of the message-passing step
we store the network states at the graph node, since the LSTM variables are not transmitted
"""
import
networkx
as
nx
from
dgl.graph
import
DGLGraph
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
argparse
from
dataset
import
load_data
,
preprocess_features
import
numpy
as
np
class
NodeReduceModule
(
nn
.
Module
):
def
__init__
(
self
,
input_dim
,
num_hidden
,
num_heads
=
3
,
input_dropout
=
None
,
...
...
@@ -101,10 +101,10 @@ class NodeUpdateModule(nn.Module):
return
{
'h'
:
h
,
'c'
:
c
,
'h_i'
:
h_i
}
class
G
i
niPath
(
nn
.
Module
):
class
G
e
ni
e
Path
(
nn
.
Module
):
def
__init__
(
self
,
num_layers
,
in_dim
,
num_hidden
,
num_classes
,
num_heads
,
activation
,
input_dropout
,
attention_dropout
,
use_residual
=
False
):
super
(
G
i
niPath
,
self
).
__init__
()
super
(
G
e
ni
e
Path
,
self
).
__init__
()
self
.
input_dropout
=
input_dropout
self
.
reduce_layers
=
nn
.
ModuleList
()
...
...
@@ -147,15 +147,18 @@ class GiniPath(nn.Module):
logits
=
[
g
.
node
[
n
][
'h'
]
for
n
in
g
.
nodes
()]
logits
=
torch
.
cat
(
logits
,
dim
=
0
)
return
logits
#train on graph g with features, and target labels. Accepts a loss function and an optimizer function which implements optimizer.step()
def
train
(
self
,
g
,
features
,
labels
,
epochs
,
loss_f
=
torch
.
nn
.
NLLLoss
,
loss_params
=
{},
optimizer
=
torch
.
optim
.
Adam
,
optimizer_parameters
=
None
,
lr
=
0.001
,
ignore
=
[
0
],
quiet
=
False
):
labels
=
torch
.
LongTensor
(
labels
)
print
(
labels
)
_
,
labels
=
torch
.
max
(
labels
,
dim
=
1
)
# convert labels and masks to tensor
if
optimizer_parameters
is
None
:
optimizer_parameters
=
self
.
parameters
()
#instantiate optimizer on given params
optimizer_f
=
optimizer
(
optimizer_parameters
,
lr
)
for
epoch
in
range
(
args
.
epochs
):
...
...
@@ -168,8 +171,11 @@ class GiniPath(nn.Module):
# forward
logits
=
self
.
forward
(
g
)
#intantiate loss on passed parameters (e.g. class weight params)
loss
=
loss_f
(
**
loss_params
)
#trim null labels
idx
=
[
i
for
i
,
a
in
enumerate
(
labels
)
if
a
not
in
ignore
]
logits
=
logits
[
idx
,
:]
labels
=
labels
[
idx
]
...
...
@@ -183,8 +189,8 @@ class GiniPath(nn.Module):
def
main
(
args
):
# dropout parameters
input_dropout
=
0.2
attention_dropout
=
0.2
input_dropout
=
args
.
idrop
attention_dropout
=
args
.
adrop
# load and preprocess dataset
adj
,
features
,
y_train
,
y_val
,
y_test
,
train_mask
,
val_mask
,
test_mask
=
load_data
(
args
.
dataset
)
...
...
@@ -194,7 +200,7 @@ def main(args):
g
=
DGLGraph
(
adj
)
# create model
model
=
G
i
niPath
(
args
.
num_layers
,
model
=
G
e
ni
e
Path
(
args
.
num_layers
,
features
.
shape
[
1
],
args
.
num_hidden
,
y_train
.
shape
[
1
],
...
...
@@ -203,7 +209,7 @@ def main(args):
input_dropout
,
attention_dropout
,
args
.
residual
)
model
.
train
(
g
,
features
,
y_train
,
epochs
=
10
)
model
.
train
(
g
,
features
,
y_train
,
epochs
=
args
.
epochs
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'GAT'
)
...
...
@@ -221,6 +227,11 @@ if __name__ == '__main__':
help
=
"use residual connection"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.001
,
help
=
"learning rate"
)
parser
.
add_argument
(
"--idrop"
,
type
=
float
,
default
=
0.2
,
help
=
"Input dropout"
)
parser
.
add_argument
(
"--adrop"
,
type
=
float
,
default
=
0.2
,
help
=
"attention dropout"
)
args
=
parser
.
parse_args
()
print
(
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment