Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1248bd24
Commit
1248bd24
authored
Jun 20, 2018
by
Lingfan Yu
Browse files
example gcn model
parent
34fac23d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
215 additions
and
0 deletions
+215
-0
examples/pytorch/dataset.py
examples/pytorch/dataset.py
+101
-0
examples/pytorch/gcn.py
examples/pytorch/gcn.py
+114
-0
No files found.
examples/pytorch/dataset.py
0 → 100644
View file @
1248bd24
import
numpy
as
np
import
pickle
as
pkl
import
networkx
as
nx
import
scipy.sparse
as
sp
import
sys
# (lingfan): following dataset loading and preprocessing code from tkipf/gcn
# https://github.com/tkipf/gcn/blob/master/gcn/utils.py
def parse_index_file(filename):
    """Parse a test-index file containing one integer index per line.

    :param filename: path to the index file.
    :return: list of int indices, in file order.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original left the handle open until garbage collection).
    with open(filename) as f:
        return [int(line.strip()) for line in f]
def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    :param idx: indices (int sequence or range) to mark True.
    :param l: total length of the mask.
    :return: numpy bool array of shape (l,).
    """
    mask = np.zeros(l)
    mask[idx] = 1
    # np.bool was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin ``bool`` is the correct dtype alias here.
    return np.array(mask, dtype=bool)
def load_data(dataset_str):
    """Load input data from the gcn/data directory.

    ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
        (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
    ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
    ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
    ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
        object;
    ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.

    All objects above must be saved using python pickle module.

    :param dataset_str: Dataset name
    :return: (adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask)
    """
    loaded = []
    for name in ('x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'):
        with open("data/ind.{}.{}".format(dataset_str, name), 'rb') as f:
            # Pickles were produced under Python 2; latin1 keeps byte strings intact.
            if sys.version_info > (3, 0):
                loaded.append(pkl.load(f, encoding='latin1'))
            else:
                loaded.append(pkl.load(f))
    x, y, tx, ty, allx, ally, graph = loaded

    test_idx_reorder = parse_index_file(
        "data/ind.{}.test.index".format(dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph):
        # pad test features/labels with zero rows so test indices form a
        # contiguous range.
        full_range = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        offset = test_idx_range - min(test_idx_range)
        tx_extended = sp.lil_matrix((len(full_range), x.shape[1]))
        tx_extended[offset, :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(full_range), y.shape[1]))
        ty_extended[offset, :] = ty
        ty = ty_extended

    # Stack labeled+unlabeled and test features, then restore original order.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    num_nodes = labels.shape[0]
    train_mask = sample_mask(idx_train, num_nodes)
    val_mask = sample_mask(idx_val, num_nodes)
    test_mask = sample_mask(idx_test, num_nodes)

    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[train_mask, :] = labels[train_mask, :]
    y_val[val_mask, :] = labels[val_mask, :]
    y_test[test_mask, :] = labels[test_mask, :]

    return (adj, features, y_train, y_val, y_test,
            train_mask, val_mask, test_mask)
def preprocess_features(features):
    """Row-normalize a (sparse) feature matrix so each non-empty row sums to 1.

    Rows that sum to zero are left all-zero: their inverse row-sum is clamped
    to 0 instead of propagating inf.

    Note: unlike the tkipf/gcn original this returns the normalized sparse
    matrix directly, not a tuple representation.

    :param features: scipy sparse feature matrix (rows = instances).
    :return: row-normalized sparse matrix.
    """
    rowsum = np.array(features.sum(1))
    # Float exponent avoids the "integers to negative integer powers" error on
    # int inputs; errstate suppresses the divide-by-zero RuntimeWarning for
    # empty rows — the resulting inf entries are zeroed immediately below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features
examples/pytorch/gcn.py
View file @
1248bd24
import
networkx
as
nx
from
dgl.graph
import
DGLGraph
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
argparse
from
dataset
import
load_data
,
preprocess_features
import
numpy
as
np
class NodeUpdateModule(nn.Module):
    """Per-node GCN update: (optional dropout) -> sum messages -> linear -> act.

    :param input_dim: size of the incoming node feature 'h'.
    :param output_dim: size of the produced node feature.
    :param act: optional activation applied to the output (e.g. F.relu).
    :param p: optional dropout probability applied to the node's own feature.
    """

    def __init__(self, input_dim, output_dim, act=None, p=None):
        super(NodeUpdateModule, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.act = act
        self.p = p

    def forward(self, node, msgs):
        h = node['h']
        if self.p is not None:
            # BUGFIX: pass training=self.training; the original always applied
            # dropout (F.dropout defaults to training=True), including at eval.
            h = F.dropout(h, p=self.p, training=self.training)
        # Aggregate messages out-of-place: the original's `h += msg` mutated
        # node['h'] in place when dropout was disabled, corrupting the stored
        # node state and risking autograd in-place errors.
        for msg in msgs:
            h = h + msg
        h = self.linear(h)
        if self.act is not None:
            h = self.act(h)
        # (lingfan): Can user directly update node instead of using return statement?
        return {'h': h}
class GCN(nn.Module):
    """Graph convolutional network built from stacked NodeUpdateModule layers.

    :param input_dim: input feature size.
    :param num_hidden: hidden layer width.
    :param num_classes: number of output classes (width of the final layer).
    :param num_layers: number of hidden layers.
    :param activation: activation applied in each hidden layer.
    :param dropout: optional dropout probability passed to every layer.
    """

    def __init__(self, input_dim, num_hidden, num_classes, num_layers,
                 activation, dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # Hidden layers: input_dim -> num_hidden -> ... -> num_hidden.
        dims = [input_dim] + [num_hidden] * num_layers
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            self.layers.append(
                NodeUpdateModule(in_dim, out_dim, act=activation, p=dropout))
        # Output layer: logits come out raw (no activation).
        self.layers.append(
            NodeUpdateModule(num_hidden, num_classes, p=dropout))

    def forward(self, g):
        # Each message is simply the source node's current feature.
        g.register_message_func(lambda src, dst, edge: src['h'])
        for layer in self.layers:
            g.register_update_func(layer)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
        return torch.cat(logits, dim=0)
def main(args):
    """Train a GCN on the dataset named by ``args.dataset``.

    Expects args to provide: dataset, num_hidden, num_layers, epochs,
    dropout, lr.
    """
    # load and preprocess dataset
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = \
        load_data(args.dataset)
    features = preprocess_features(features)

    # initialize graph
    g = DGLGraph(adj)

    # create GCN model
    model = GCN(features.shape[1],
                args.num_hidden,
                y_train.shape[1],
                args.num_layers,
                F.relu,
                args.dropout)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # convert labels and masks to tensor
    labels = torch.FloatTensor(y_train)
    mask = torch.FloatTensor(train_mask.astype(np.float32))

    for epoch in range(args.epochs):
        # reset grad
        optimizer.zero_grad()

        # reset graph states to the (sparse) input features
        for n in g.nodes():
            g.node[n]['h'] = torch.FloatTensor(features[n].toarray())

        # forward: call the module itself so nn.Module hooks run
        logits = model(g)

        # masked cross entropy loss
        # TODO: (lingfan) use gather to speed up
        logp = F.log_softmax(logits, 1)
        # BUGFIX: cross-entropy is the NEGATIVE mean of the masked
        # log-likelihood. Without the minus sign, minimizing the loss pushes
        # the true-class log-probabilities toward -inf (the wrong direction).
        loss = -torch.mean(logp * labels * mask.view(-1, 1))
        print("epoch {} loss: {}".format(epoch, loss.item()))

        loss.backward()
        optimizer.step()
if __name__ == '__main__':
    # Command-line entry point: parse hyper-parameters, echo them, and train.
    parser = argparse.ArgumentParser(description='GCN')
    parser.add_argument("--dataset", type=str, required=True,
                        help="dataset name")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of gcn layers")
    parser.add_argument("--num-hidden", type=int, default=64,
                        help="number of hidden units")
    parser.add_argument("--epochs", type=int, default=10,
                        help="training epoch")
    parser.add_argument("--dropout", type=float, default=None,
                        help="dropout probability")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    args = parser.parse_args()
    print(args)
    main(args)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment