Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ef18dab2
Commit
ef18dab2
authored
Jul 10, 2018
by
Ivan Brugere
Browse files
initial ginipath
Initial example file for ginipath
parent
9219349a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
254 additions
and
0 deletions
+254
-0
.gitignore
.gitignore
+27
-0
examples/pytorch/ginipath.py
examples/pytorch/ginipath.py
+227
-0
No files found.
.gitignore
View file @
ef18dab2
...
...
@@ -105,3 +105,30 @@ venv.bak/
*.swp
*.swo
examples/pytorch/data/ind.pubmed.y
examples/pytorch/data/ind.pubmed.x
examples/pytorch/data/ind.pubmed.ty
examples/pytorch/data/ind.pubmed.tx
examples/pytorch/data/ind.pubmed.test.index
examples/pytorch/data/ind.pubmed.graph
examples/pytorch/data/ind.pubmed.ally
examples/pytorch/data/ind.pubmed.allx
examples/pytorch/data/ind.cora.y
examples/pytorch/data/ind.cora.x
examples/pytorch/data/ind.cora.ty
examples/pytorch/data/ind.cora.tx
examples/pytorch/data/ind.cora.test.index
examples/pytorch/data/ind.cora.graph
examples/pytorch/data/ind.cora.ally
examples/pytorch/data/ind.cora.allx
examples/pytorch/data/ind.citeseer.y
examples/pytorch/data/ind.citeseer.x
examples/pytorch/data/ind.citeseer.ty
examples/pytorch/data/ind.citeseer.tx
examples/pytorch/data/ind.citeseer.test.index
examples/pytorch/data/ind.citeseer.graph
examples/pytorch/data/ind.citeseer.ally
examples/pytorch/data/ind.citeseer.allx
examples/pytorch/.DS_Store
examples/.DS_Store
.DS_Store
examples/pytorch/ginipath.py
0 → 100644
View file @
ef18dab2
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 9 13:34:38 2018
@author: ivabruge
"""
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
import
networkx
as
nx
from
dgl.graph
import
DGLGraph
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
argparse
from
dataset
import
load_data
,
preprocess_features
import
numpy
as
np
class
NodeReduceModule
(
nn
.
Module
):
def
__init__
(
self
,
input_dim
,
num_hidden
,
num_heads
=
3
,
input_dropout
=
None
,
attention_dropout
=
None
,
act
=
lambda
x
:
F
.
softmax
(
F
.
leaky_relu
(
x
),
dim
=
0
)):
super
(
NodeReduceModule
,
self
).
__init__
()
self
.
num_heads
=
num_heads
self
.
input_dropout
=
input_dropout
self
.
attention_dropout
=
attention_dropout
self
.
act
=
act
self
.
fc
=
nn
.
ModuleList
(
[
nn
.
Linear
(
input_dim
,
num_hidden
,
bias
=
False
)
for
_
in
range
(
num_heads
)])
self
.
attention
=
nn
.
ModuleList
(
[
nn
.
Linear
(
num_hidden
*
2
,
1
,
bias
=
False
)
for
_
in
range
(
num_heads
)])
def
forward
(
self
,
msgs
):
src
,
dst
=
zip
(
*
msgs
)
hu
=
torch
.
cat
(
src
,
dim
=
0
)
# neighbor repr
hv
=
torch
.
cat
(
dst
,
dim
=
0
)
msgs_repr
=
[]
# iterate for each head
for
i
in
range
(
self
.
num_heads
):
# calc W*hself and W*hneigh
hvv
=
self
.
fc
[
i
](
hv
)
huu
=
self
.
fc
[
i
](
hu
)
# calculate W*hself||W*hneigh
h
=
torch
.
cat
((
hvv
,
huu
),
dim
=
1
)
a
=
self
.
act
(
self
.
attention
[
i
](
h
))
if
self
.
attention_dropout
is
not
None
:
a
=
F
.
dropout
(
a
,
self
.
attention_dropout
)
if
self
.
input_dropout
is
not
None
:
hvv
=
F
.
dropout
(
hvv
,
self
.
input_dropout
)
h
=
torch
.
sum
(
a
*
hvv
,
0
,
keepdim
=
True
)
msgs_repr
.
append
(
h
)
return
msgs_repr
class
NodeUpdateModule
(
nn
.
Module
):
def
__init__
(
self
,
residual
,
fc
,
act
,
aggregator
):
super
(
NodeUpdateModule
,
self
).
__init__
()
self
.
residual
=
residual
self
.
fc
=
fc
self
.
act
=
act
self
.
aggregator
=
aggregator
def
forward
(
self
,
node
,
msgs_repr
):
# apply residual connection and activation for each head
for
i
in
range
(
len
(
msgs_repr
)):
if
self
.
residual
:
h
=
self
.
fc
[
i
](
node
[
'h'
])
msgs_repr
[
i
]
=
msgs_repr
[
i
]
+
h
if
self
.
act
is
not
None
:
msgs_repr
[
i
]
=
self
.
act
(
msgs_repr
[
i
])
# aggregate multi-head results
h
=
self
.
aggregator
(
msgs_repr
)
c0
=
torch
.
zeros
(
h
.
shape
)
if
node
[
'c'
]
is
None
:
c0
=
torch
.
zeros
(
h
.
shape
)
else
:
c0
=
node
[
'c'
]
if
node
[
'h_i'
]
is
None
:
h0
=
torch
.
zeros
(
h
.
shape
)
else
:
h0
=
node
[
'h_i'
]
lstm
=
nn
.
LSTM
(
input_size
=
h
.
shape
[
1
],
hidden_size
=
h
.
shape
[
1
],
num_layers
=
1
)
#add dimension to handle sequential (create sequence of length 1)
h
,
(
h_i
,
c
)
=
lstm
(
h
.
unsqueeze
(
0
),
(
h0
.
unsqueeze
(
0
),
c0
.
unsqueeze
(
0
)))
#remove sequential dim
h
=
torch
.
squeeze
(
h
,
0
)
h_i
=
torch
.
squeeze
(
h
,
0
)
c
=
torch
.
squeeze
(
c
,
0
)
return
{
'h'
:
h
,
'c'
:
c
,
'h_i'
:
h_i
}
class
GiniPath
(
nn
.
Module
):
def
__init__
(
self
,
num_layers
,
in_dim
,
num_hidden
,
num_classes
,
num_heads
,
activation
,
input_dropout
,
attention_dropout
,
use_residual
=
False
):
super
(
GiniPath
,
self
).
__init__
()
self
.
input_dropout
=
input_dropout
self
.
reduce_layers
=
nn
.
ModuleList
()
self
.
update_layers
=
nn
.
ModuleList
()
# hidden layers
for
i
in
range
(
num_layers
):
if
i
==
0
:
last_dim
=
in_dim
residual
=
False
else
:
last_dim
=
num_hidden
*
num_heads
# because of concat heads
residual
=
use_residual
self
.
reduce_layers
.
append
(
NodeReduceModule
(
last_dim
,
num_hidden
,
num_heads
,
input_dropout
,
attention_dropout
))
self
.
update_layers
.
append
(
NodeUpdateModule
(
residual
,
self
.
reduce_layers
[
-
1
].
fc
,
activation
,
lambda
x
:
torch
.
cat
(
x
,
1
)))
# projection
self
.
reduce_layers
.
append
(
NodeReduceModule
(
num_hidden
*
num_heads
,
num_classes
,
1
,
input_dropout
,
attention_dropout
))
self
.
update_layers
.
append
(
NodeUpdateModule
(
False
,
self
.
reduce_layers
[
-
1
].
fc
,
None
,
sum
))
def
forward
(
self
,
g
):
g
.
register_message_func
(
lambda
src
,
dst
,
edge
:
(
src
[
'h'
],
dst
[
'h'
]))
for
reduce_func
,
update_func
in
zip
(
self
.
reduce_layers
,
self
.
update_layers
):
# apply dropout
if
self
.
input_dropout
is
not
None
:
# TODO (lingfan): use batched dropout once we have better api
# for global manipulation
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'h'
]
=
F
.
dropout
(
g
.
node
[
n
][
'h'
],
p
=
self
.
input_dropout
)
g
.
node
[
n
][
'c'
]
=
None
g
.
node
[
n
][
'h_i'
]
=
None
g
.
register_reduce_func
(
reduce_func
)
g
.
register_update_func
(
update_func
)
g
.
update_all
()
logits
=
[
g
.
node
[
n
][
'h'
]
for
n
in
g
.
nodes
()]
logits
=
torch
.
cat
(
logits
,
dim
=
0
)
return
logits
def
train
(
self
,
g
,
features
,
labels
,
epochs
,
loss_f
=
torch
.
nn
.
NLLLoss
,
loss_params
=
{},
optimizer
=
torch
.
optim
.
Adam
,
optimizer_parameters
=
None
,
lr
=
0.001
,
ignore
=
[
0
],
quiet
=
False
):
labels
=
torch
.
LongTensor
(
labels
)
print
(
labels
)
_
,
labels
=
torch
.
max
(
labels
,
dim
=
1
)
# convert labels and masks to tensor
if
optimizer_parameters
is
None
:
optimizer_parameters
=
self
.
parameters
()
optimizer_f
=
optimizer
(
optimizer_parameters
,
lr
)
for
epoch
in
range
(
args
.
epochs
):
# reset grad
optimizer_f
.
zero_grad
()
# reset graph states
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'h'
]
=
torch
.
FloatTensor
(
features
[
n
].
toarray
())
# forward
logits
=
self
.
forward
(
g
)
loss
=
loss_f
(
**
loss_params
)
idx
=
[
i
for
i
,
a
in
enumerate
(
labels
)
if
a
not
in
ignore
]
logits
=
logits
[
idx
,
:]
labels
=
labels
[
idx
]
out
=
loss
(
logits
,
labels
)
if
not
quiet
:
print
(
"epoch {} loss: {}"
.
format
(
epoch
,
out
))
out
.
backward
()
optimizer_f
.
step
()
def
main
(
args
):
# dropout parameters
input_dropout
=
0.2
attention_dropout
=
0.2
# load and preprocess dataset
adj
,
features
,
y_train
,
y_val
,
y_test
,
train_mask
,
val_mask
,
test_mask
=
load_data
(
args
.
dataset
)
features
=
preprocess_features
(
features
)
# initialize graph
g
=
DGLGraph
(
adj
)
# create model
model
=
GiniPath
(
args
.
num_layers
,
features
.
shape
[
1
],
args
.
num_hidden
,
y_train
.
shape
[
1
],
args
.
num_heads
,
F
.
elu
,
input_dropout
,
attention_dropout
,
args
.
residual
)
model
.
train
(
g
,
features
,
y_train
,
epochs
=
10
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'GAT'
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
required
=
True
,
help
=
"dataset name"
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
10
,
help
=
"training epoch"
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
3
,
help
=
"number of attentional heads to use"
)
parser
.
add_argument
(
"--num-layers"
,
type
=
int
,
default
=
1
,
help
=
"number of hidden layers"
)
parser
.
add_argument
(
"--num-hidden"
,
type
=
int
,
default
=
8
,
help
=
"size of hidden units"
)
parser
.
add_argument
(
"--residual"
,
action
=
"store_true"
,
help
=
"use residual connection"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.001
,
help
=
"learning rate"
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
args
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment