OpenDAS / dgl · Commits · ee241699

Unverified commit ee241699, authored Aug 09, 2018 by Minjie Wang, committed by GitHub on Aug 09, 2018
Parent: 4673b96f

GAT model (#37)

* GAT model
* fix output projection to have only one head

Showing 5 changed files with 485 additions and 225 deletions:

    examples/pytorch/gat.py             +0    -193
    examples/pytorch/gat/gat.py         +222  -0
    examples/pytorch/gat/gat_batch.py   +222  -0
    python/dgl/graph.py                 +40   -20
    tests/test_basics2.py               +1    -12
examples/pytorch/gat.py (deleted, file mode 100644 → 0)
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
import
networkx
as
nx
from
dgl.graph
import
DGLGraph
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
argparse
from
dataset
import
load_data
,
preprocess_features
import
numpy
as
np
class
NodeReduceModule
(
nn
.
Module
):
def
__init__
(
self
,
input_dim
,
num_hidden
,
num_heads
=
3
,
input_dropout
=
None
,
attention_dropout
=
None
):
super
(
NodeReduceModule
,
self
).
__init__
()
self
.
num_heads
=
num_heads
self
.
input_dropout
=
input_dropout
self
.
attention_dropout
=
attention_dropout
self
.
fc
=
nn
.
ModuleList
(
[
nn
.
Linear
(
input_dim
,
num_hidden
,
bias
=
False
)
for
_
in
range
(
num_heads
)])
self
.
attention
=
nn
.
ModuleList
(
[
nn
.
Linear
(
num_hidden
*
2
,
1
,
bias
=
False
)
for
_
in
range
(
num_heads
)])
def
forward
(
self
,
msgs
):
src
,
dst
=
zip
(
*
msgs
)
hu
=
torch
.
cat
(
src
,
dim
=
0
)
# neighbor repr
hv
=
torch
.
cat
(
dst
,
dim
=
0
)
msgs_repr
=
[]
# iterate for each head
for
i
in
range
(
self
.
num_heads
):
# calc W*hself and W*hneigh
hvv
=
self
.
fc
[
i
](
hv
)
huu
=
self
.
fc
[
i
](
hu
)
# calculate W*hself||W*hneigh
h
=
torch
.
cat
((
hvv
,
huu
),
dim
=
1
)
a
=
F
.
leaky_relu
(
self
.
attention
[
i
](
h
))
a
=
F
.
softmax
(
a
,
dim
=
0
)
if
self
.
attention_dropout
is
not
None
:
a
=
F
.
dropout
(
a
,
self
.
attention_dropout
)
if
self
.
input_dropout
is
not
None
:
hvv
=
F
.
dropout
(
hvv
,
self
.
input_dropout
)
h
=
torch
.
sum
(
a
*
hvv
,
0
,
keepdim
=
True
)
msgs_repr
.
append
(
h
)
return
msgs_repr
class
NodeUpdateModule
(
nn
.
Module
):
def
__init__
(
self
,
residual
,
fc
,
act
,
aggregator
):
super
(
NodeUpdateModule
,
self
).
__init__
()
self
.
residual
=
residual
self
.
fc
=
fc
self
.
act
=
act
self
.
aggregator
=
aggregator
def
forward
(
self
,
node
,
msgs_repr
):
# apply residual connection and activation for each head
for
i
in
range
(
len
(
msgs_repr
)):
if
self
.
residual
:
h
=
self
.
fc
[
i
](
node
[
'h'
])
msgs_repr
[
i
]
=
msgs_repr
[
i
]
+
h
if
self
.
act
is
not
None
:
msgs_repr
[
i
]
=
self
.
act
(
msgs_repr
[
i
])
# aggregate multi-head results
h
=
self
.
aggregator
(
msgs_repr
)
return
{
'h'
:
h
}
class
GAT
(
nn
.
Module
):
def
__init__
(
self
,
num_layers
,
in_dim
,
num_hidden
,
num_classes
,
num_heads
,
activation
,
input_dropout
,
attention_dropout
,
use_residual
=
False
):
super
(
GAT
,
self
).
__init__
()
self
.
input_dropout
=
input_dropout
self
.
reduce_layers
=
nn
.
ModuleList
()
self
.
update_layers
=
nn
.
ModuleList
()
# hidden layers
for
i
in
range
(
num_layers
):
if
i
==
0
:
last_dim
=
in_dim
residual
=
False
else
:
last_dim
=
num_hidden
*
num_heads
# because of concat heads
residual
=
use_residual
self
.
reduce_layers
.
append
(
NodeReduceModule
(
last_dim
,
num_hidden
,
num_heads
,
input_dropout
,
attention_dropout
))
self
.
update_layers
.
append
(
NodeUpdateModule
(
residual
,
self
.
reduce_layers
[
-
1
].
fc
,
activation
,
lambda
x
:
torch
.
cat
(
x
,
1
)))
# projection
self
.
reduce_layers
.
append
(
NodeReduceModule
(
num_hidden
*
num_heads
,
num_classes
,
1
,
input_dropout
,
attention_dropout
))
self
.
update_layers
.
append
(
NodeUpdateModule
(
False
,
self
.
reduce_layers
[
-
1
].
fc
,
None
,
sum
))
def
forward
(
self
,
g
):
g
.
register_message_func
(
lambda
src
,
dst
,
edge
:
(
src
[
'h'
],
dst
[
'h'
]))
for
reduce_func
,
update_func
in
zip
(
self
.
reduce_layers
,
self
.
update_layers
):
# apply dropout
if
self
.
input_dropout
is
not
None
:
# TODO (lingfan): use batched dropout once we have better api
# for global manipulation
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'h'
]
=
F
.
dropout
(
g
.
node
[
n
][
'h'
],
p
=
self
.
input_dropout
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_update_func
(
update_func
)
g
.
update_all
()
logits
=
[
g
.
node
[
n
][
'h'
]
for
n
in
g
.
nodes
()]
logits
=
torch
.
cat
(
logits
,
dim
=
0
)
return
logits
def
main
(
args
):
# dropout parameters
input_dropout
=
0.2
attention_dropout
=
0.2
# load and preprocess dataset
adj
,
features
,
y_train
,
y_val
,
y_test
,
train_mask
,
val_mask
,
test_mask
=
load_data
(
args
.
dataset
)
features
=
preprocess_features
(
features
)
# initialize graph
g
=
DGLGraph
(
adj
)
# create model
model
=
GAT
(
args
.
num_layers
,
features
.
shape
[
1
],
args
.
num_hidden
,
y_train
.
shape
[
1
],
args
.
num_heads
,
F
.
elu
,
input_dropout
,
attention_dropout
,
args
.
residual
)
# use optimizer
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
# convert labels and masks to tensor
labels
=
torch
.
FloatTensor
(
y_train
)
mask
=
torch
.
FloatTensor
(
train_mask
.
astype
(
np
.
float32
))
n_train
=
torch
.
sum
(
mask
)
for
epoch
in
range
(
args
.
epochs
):
# reset grad
optimizer
.
zero_grad
()
# reset graph states
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'h'
]
=
torch
.
FloatTensor
(
features
[
n
].
toarray
())
# forward
logits
=
model
.
forward
(
g
)
# masked cross entropy loss
# TODO: (lingfan) use gather to speed up
logp
=
F
.
log_softmax
(
logits
,
1
)
loss
=
-
torch
.
sum
(
logp
*
labels
*
mask
.
view
(
-
1
,
1
))
/
n_train
print
(
"epoch {} loss: {}"
.
format
(
epoch
,
loss
.
item
()))
loss
.
backward
()
optimizer
.
step
()
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'GAT'
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
required
=
True
,
help
=
"dataset name"
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
10
,
help
=
"training epoch"
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
3
,
help
=
"number of attentional heads to use"
)
parser
.
add_argument
(
"--num-layers"
,
type
=
int
,
default
=
1
,
help
=
"number of hidden layers"
)
parser
.
add_argument
(
"--num-hidden"
,
type
=
int
,
default
=
8
,
help
=
"size of hidden units"
)
parser
.
add_argument
(
"--residual"
,
action
=
"store_true"
,
help
=
"use residual connection"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.001
,
help
=
"learning rate"
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
args
)
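For each destination node and each head, the removed NodeReduceModule computes attention weights softmax(LeakyReLU(a^T [W h_dst || W h_src])) over the incoming messages and then a weighted sum. A minimal, self-contained sketch of that single-head step, assuming plain dense tensors instead of DGL message lists (the helper name one_head_attention is made up here, and the sketch weights the transformed neighbor features, as in the GAT paper):

    # Illustrative only (not from the commit): one attention head for a single
    # destination node, with dense tensors standing in for DGL message batches.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def one_head_attention(h_self, h_neigh, fc, attn):
        """h_self: (1, in_dim); h_neigh: (deg, in_dim); fc, attn: nn.Linear."""
        w_self = fc(h_self)                                             # (1, hidden)
        w_neigh = fc(h_neigh)                                           # (deg, hidden)
        pair = torch.cat((w_self.expand_as(w_neigh), w_neigh), dim=1)   # (deg, 2*hidden)
        alpha = F.softmax(F.leaky_relu(attn(pair)), dim=0)              # weights over neighbors
        return torch.sum(alpha * w_neigh, dim=0, keepdim=True)          # (1, hidden)

    fc = nn.Linear(5, 4, bias=False)
    attn = nn.Linear(8, 1, bias=False)
    print(one_head_attention(torch.randn(1, 5), torch.randn(3, 5), fc, attn).shape)
    # torch.Size([1, 4])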
examples/pytorch/gat/gat.py (new file, file mode 0 → 100644)
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
import
argparse
import
numpy
as
np
import
time
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
dgl
import
DGLGraph
from
dgl.data
import
register_data_args
,
load_data
def
gat_message
(
src
,
edge
):
return
{
'ft'
:
src
[
'ft'
],
'a2'
:
src
[
'a2'
]}
class
GATReduce
(
nn
.
Module
):
def
__init__
(
self
,
attn_drop
):
super
(
GATReduce
,
self
).
__init__
()
self
.
attn_drop
=
attn_drop
def
forward
(
self
,
node
,
msgs
):
a1
=
torch
.
unsqueeze
(
node
[
'a1'
],
0
)
# shape (1, 1)
a2
=
torch
.
cat
([
torch
.
unsqueeze
(
m
[
'a2'
],
0
)
for
m
in
msgs
],
dim
=
0
)
# shape (deg, 1)
ft
=
torch
.
cat
([
torch
.
unsqueeze
(
m
[
'ft'
],
0
)
for
m
in
msgs
],
dim
=
0
)
# shape (deg, D)
# attention
a
=
a1
+
a2
# shape (deg, 1)
e
=
F
.
softmax
(
F
.
leaky_relu
(
a
),
dim
=
0
)
if
self
.
attn_drop
!=
0.0
:
e
=
F
.
dropout
(
e
,
self
.
attn_drop
)
return
torch
.
sum
(
e
*
ft
,
dim
=
0
)
# shape (D,)
class
GATFinalize
(
nn
.
Module
):
def
__init__
(
self
,
headid
,
indim
,
hiddendim
,
activation
,
residual
):
super
(
GATFinalize
,
self
).
__init__
()
self
.
headid
=
headid
self
.
activation
=
activation
self
.
residual
=
residual
self
.
residual_fc
=
None
if
residual
:
if
indim
!=
hiddendim
:
self
.
residual_fc
=
nn
.
Linear
(
indim
,
hiddendim
)
def
forward
(
self
,
node
,
accum
):
ret
=
accum
if
self
.
residual
:
if
self
.
residual_fc
is
not
None
:
ret
=
self
.
residual_fc
(
node
[
'h'
])
+
ret
else
:
ret
=
node
[
'h'
]
+
ret
return
{
'head%d'
%
self
.
headid
:
self
.
activation
(
ret
)}
class
GATPrepare
(
nn
.
Module
):
def
__init__
(
self
,
indim
,
hiddendim
,
drop
):
super
(
GATPrepare
,
self
).
__init__
()
self
.
fc
=
nn
.
Linear
(
indim
,
hiddendim
)
self
.
drop
=
drop
self
.
attn_l
=
nn
.
Linear
(
hiddendim
,
1
)
self
.
attn_r
=
nn
.
Linear
(
hiddendim
,
1
)
def
forward
(
self
,
feats
):
h
=
feats
if
self
.
drop
!=
0.0
:
h
=
F
.
dropout
(
h
,
self
.
drop
)
ft
=
self
.
fc
(
h
)
a1
=
self
.
attn_l
(
ft
)
a2
=
self
.
attn_r
(
ft
)
return
{
'h'
:
h
,
'ft'
:
ft
,
'a1'
:
a1
,
'a2'
:
a2
}
class
GAT
(
nn
.
Module
):
def
__init__
(
self
,
nx_graph
,
num_layers
,
in_dim
,
num_hidden
,
num_classes
,
num_heads
,
activation
,
in_drop
,
attn_drop
,
residual
):
super
(
GAT
,
self
).
__init__
()
self
.
g
=
DGLGraph
(
nx_graph
)
self
.
num_layers
=
num_layers
# one extra output projection
self
.
num_heads
=
num_heads
self
.
prp
=
nn
.
ModuleList
()
self
.
red
=
nn
.
ModuleList
()
self
.
fnl
=
nn
.
ModuleList
()
# input projection (no residual)
for
hid
in
range
(
num_heads
):
self
.
prp
.
append
(
GATPrepare
(
in_dim
,
num_hidden
,
in_drop
))
self
.
red
.
append
(
GATReduce
(
attn_drop
))
self
.
fnl
.
append
(
GATFinalize
(
hid
,
in_dim
,
num_hidden
,
activation
,
False
))
# hidden layers
for
l
in
range
(
num_layers
-
1
):
for
hid
in
range
(
num_heads
):
# due to multi-head, the in_dim = num_hidden * num_heads
self
.
prp
.
append
(
GATPrepare
(
num_hidden
*
num_heads
,
num_hidden
,
in_drop
))
self
.
red
.
append
(
GATReduce
(
attn_drop
))
self
.
fnl
.
append
(
GATFinalize
(
hid
,
num_hidden
*
num_heads
,
num_hidden
,
activation
,
residual
))
# output projection
self
.
prp
.
append
(
GATPrepare
(
num_hidden
*
num_heads
,
num_classes
,
in_drop
))
self
.
red
.
append
(
GATReduce
(
attn_drop
))
self
.
fnl
.
append
(
GATFinalize
(
0
,
num_hidden
*
num_heads
,
num_classes
,
activation
,
residual
))
# sanity check
assert
len
(
self
.
prp
)
==
self
.
num_layers
*
self
.
num_heads
+
1
assert
len
(
self
.
red
)
==
self
.
num_layers
*
self
.
num_heads
+
1
assert
len
(
self
.
fnl
)
==
self
.
num_layers
*
self
.
num_heads
+
1
def
forward
(
self
,
features
,
train_nodes
):
last
=
features
for
l
in
range
(
self
.
num_layers
):
for
hid
in
range
(
self
.
num_heads
):
i
=
l
*
self
.
num_heads
+
hid
# prepare
for
n
,
h
in
last
.
items
():
self
.
g
.
nodes
[
n
].
update
(
self
.
prp
[
i
](
h
))
# message passing
self
.
g
.
update_all
(
gat_message
,
self
.
red
[
i
],
self
.
fnl
[
i
])
# merge all the heads
last
=
{}
for
n
in
self
.
g
.
nodes
():
last
[
n
]
=
torch
.
cat
(
[
self
.
g
.
nodes
[
n
][
'head%d'
%
hid
]
for
hid
in
range
(
self
.
num_heads
)])
# output projection
for
n
,
h
in
last
.
items
():
self
.
g
.
nodes
[
n
].
update
(
self
.
prp
[
-
1
](
h
))
self
.
g
.
update_all
(
gat_message
,
self
.
red
[
-
1
],
self
.
fnl
[
-
1
])
return
torch
.
cat
([
torch
.
unsqueeze
(
self
.
g
.
nodes
[
n
][
'head0'
],
0
)
for
n
in
train_nodes
])
def
main
(
args
):
# load and preprocess dataset
data
=
load_data
(
args
)
# features of each samples
features
=
{}
labels
=
[]
train_nodes
=
[]
for
n
in
data
.
graph
.
nodes
():
features
[
n
]
=
torch
.
FloatTensor
(
data
.
features
[
n
,
:])
if
data
.
train_mask
[
n
]
==
1
:
train_nodes
.
append
(
n
)
labels
.
append
(
data
.
labels
[
n
])
labels
=
torch
.
LongTensor
(
labels
)
in_feats
=
data
.
features
.
shape
[
1
]
n_classes
=
data
.
num_labels
n_edges
=
data
.
graph
.
number_of_edges
()
if
args
.
gpu
<
0
:
cuda
=
False
else
:
cuda
=
True
torch
.
cuda
.
set_device
(
args
.
gpu
)
features
=
{
k
:
v
.
cuda
()
for
k
,
v
in
features
.
items
()}
labels
=
labels
.
cuda
()
# create model
model
=
GAT
(
data
.
graph
,
args
.
num_layers
,
in_feats
,
args
.
num_hidden
,
n_classes
,
args
.
num_heads
,
F
.
elu
,
args
.
in_drop
,
args
.
attn_drop
,
args
.
residual
)
if
cuda
:
model
.
cuda
()
# use optimizer
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
# initialize graph
dur
=
[]
for
epoch
in
range
(
args
.
epochs
):
if
epoch
>=
3
:
t0
=
time
.
time
()
# forward
logits
=
model
(
features
,
train_nodes
)
logp
=
F
.
log_softmax
(
logits
,
1
)
loss
=
F
.
nll_loss
(
logp
,
labels
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
if
epoch
>=
3
:
dur
.
append
(
time
.
time
()
-
t0
)
print
(
"Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}"
.
format
(
epoch
,
loss
.
item
(),
np
.
mean
(
dur
),
n_edges
/
np
.
mean
(
dur
)
/
1000
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'GAT'
)
register_data_args
(
parser
)
parser
.
add_argument
(
"--gpu"
,
type
=
int
,
default
=-
1
,
help
=
"Which GPU to use. Set -1 to use CPU."
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
20
,
help
=
"number of training epochs"
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
,
help
=
"number of attentional heads to use"
)
parser
.
add_argument
(
"--num-layers"
,
type
=
int
,
default
=
1
,
help
=
"number of hidden layers"
)
parser
.
add_argument
(
"--num-hidden"
,
type
=
int
,
default
=
8
,
help
=
"size of hidden units"
)
parser
.
add_argument
(
"--residual"
,
action
=
"store_false"
,
help
=
"use residual connection"
)
parser
.
add_argument
(
"--in-drop"
,
type
=
float
,
default
=
.
6
,
help
=
"input feature dropout"
)
parser
.
add_argument
(
"--attn-drop"
,
type
=
float
,
default
=
.
6
,
help
=
"attention dropout"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.005
,
help
=
"learning rate"
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
args
)
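GATPrepare relies on the usual GAT factorization of the edge score: rather than applying one attention vector a to the per-edge concatenation [W h_i || W h_j], it splits a into a left and a right half and precomputes a1 = attn_l(W h_i) and a2 = attn_r(W h_j) per node, so GATReduce only needs a broadcasted add. A small numerical check of that equivalence (illustrative only, not part of the commit; the layer names a_full, a_l, a_r are made up here):

    # Illustrative only: a^T [Wh_i || Wh_j] == a_l(Wh_i) + a_r(Wh_j)
    # when a_l and a_r hold the two halves of a.
    import torch
    import torch.nn as nn

    hidden = 8
    a_full = nn.Linear(2 * hidden, 1, bias=False)
    a_l = nn.Linear(hidden, 1, bias=False)
    a_r = nn.Linear(hidden, 1, bias=False)
    with torch.no_grad():
        a_l.weight.copy_(a_full.weight[:, :hidden])
        a_r.weight.copy_(a_full.weight[:, hidden:])

    wh_i, wh_j = torch.randn(1, hidden), torch.randn(1, hidden)
    full = a_full(torch.cat([wh_i, wh_j], dim=1))
    split = a_l(wh_i) + a_r(wh_j)
    print(torch.allclose(full, split))  # True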
examples/pytorch/gat/gat_batch.py (new file, file mode 0 → 100644)
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
GAT with batch processing
"""
import
argparse
import
numpy
as
np
import
time
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
from
dgl
import
DGLGraph
from
dgl.data
import
register_data_args
,
load_data
def
gat_message
(
src
,
edge
):
return
{
'ft'
:
src
[
'ft'
],
'a2'
:
src
[
'a2'
]}
class
GATReduce
(
nn
.
Module
):
def
__init__
(
self
,
attn_drop
):
super
(
GATReduce
,
self
).
__init__
()
self
.
attn_drop
=
attn_drop
def
forward
(
self
,
node
,
msgs
):
a1
=
torch
.
unsqueeze
(
node
[
'a1'
],
1
)
# shape (B, 1, 1)
a2
=
msgs
[
'a2'
]
# shape (B, deg, 1)
ft
=
msgs
[
'ft'
]
# shape (B, deg, D)
# attention
a
=
a1
+
a2
# shape (B, deg, 1)
e
=
F
.
softmax
(
F
.
leaky_relu
(
a
),
dim
=
1
)
if
self
.
attn_drop
!=
0.0
:
e
=
F
.
dropout
(
e
,
self
.
attn_drop
)
return
torch
.
sum
(
e
*
ft
,
dim
=
1
)
# shape (B, D)
class
GATFinalize
(
nn
.
Module
):
def
__init__
(
self
,
headid
,
indim
,
hiddendim
,
activation
,
residual
):
super
(
GATFinalize
,
self
).
__init__
()
self
.
headid
=
headid
self
.
activation
=
activation
self
.
residual
=
residual
self
.
residual_fc
=
None
if
residual
:
if
indim
!=
hiddendim
:
self
.
residual_fc
=
nn
.
Linear
(
indim
,
hiddendim
)
def
forward
(
self
,
node
,
accum
):
ret
=
accum
if
self
.
residual
:
if
self
.
residual_fc
is
not
None
:
ret
=
self
.
residual_fc
(
node
[
'h'
])
+
ret
else
:
ret
=
node
[
'h'
]
+
ret
return
{
'head%d'
%
self
.
headid
:
self
.
activation
(
ret
)}
class
GATPrepare
(
nn
.
Module
):
def
__init__
(
self
,
indim
,
hiddendim
,
drop
):
super
(
GATPrepare
,
self
).
__init__
()
self
.
fc
=
nn
.
Linear
(
indim
,
hiddendim
)
self
.
drop
=
drop
self
.
attn_l
=
nn
.
Linear
(
hiddendim
,
1
)
self
.
attn_r
=
nn
.
Linear
(
hiddendim
,
1
)
def
forward
(
self
,
feats
):
h
=
feats
if
self
.
drop
!=
0.0
:
h
=
F
.
dropout
(
h
,
self
.
drop
)
ft
=
self
.
fc
(
h
)
a1
=
self
.
attn_l
(
ft
)
a2
=
self
.
attn_r
(
ft
)
return
{
'h'
:
h
,
'ft'
:
ft
,
'a1'
:
a1
,
'a2'
:
a2
}
class
GAT
(
nn
.
Module
):
def
__init__
(
self
,
g
,
num_layers
,
in_dim
,
num_hidden
,
num_classes
,
num_heads
,
activation
,
in_drop
,
attn_drop
,
residual
):
super
(
GAT
,
self
).
__init__
()
self
.
g
=
g
self
.
num_layers
=
num_layers
self
.
num_heads
=
num_heads
self
.
prp
=
nn
.
ModuleList
()
self
.
red
=
nn
.
ModuleList
()
self
.
fnl
=
nn
.
ModuleList
()
# input projection (no residual)
for
hid
in
range
(
num_heads
):
self
.
prp
.
append
(
GATPrepare
(
in_dim
,
num_hidden
,
in_drop
))
self
.
red
.
append
(
GATReduce
(
attn_drop
))
self
.
fnl
.
append
(
GATFinalize
(
hid
,
in_dim
,
num_hidden
,
activation
,
False
))
# hidden layers
for
l
in
range
(
num_layers
-
1
):
for
hid
in
range
(
num_heads
):
# due to multi-head, the in_dim = num_hidden * num_heads
self
.
prp
.
append
(
GATPrepare
(
num_hidden
*
num_heads
,
num_hidden
,
in_drop
))
self
.
red
.
append
(
GATReduce
(
attn_drop
))
self
.
fnl
.
append
(
GATFinalize
(
hid
,
num_hidden
*
num_heads
,
num_hidden
,
activation
,
residual
))
# output projection
self
.
prp
.
append
(
GATPrepare
(
num_hidden
*
num_heads
,
num_classes
,
in_drop
))
self
.
red
.
append
(
GATReduce
(
attn_drop
))
self
.
fnl
.
append
(
GATFinalize
(
0
,
num_hidden
*
num_heads
,
num_classes
,
activation
,
residual
))
# sanity check
assert
len
(
self
.
prp
)
==
self
.
num_layers
*
self
.
num_heads
+
1
assert
len
(
self
.
red
)
==
self
.
num_layers
*
self
.
num_heads
+
1
assert
len
(
self
.
fnl
)
==
self
.
num_layers
*
self
.
num_heads
+
1
def
forward
(
self
,
features
):
last
=
features
for
l
in
range
(
self
.
num_layers
):
for
hid
in
range
(
self
.
num_heads
):
i
=
l
*
self
.
num_heads
+
hid
# prepare
self
.
g
.
set_n_repr
(
self
.
prp
[
i
](
last
))
# message passing
self
.
g
.
update_all
(
gat_message
,
self
.
red
[
i
],
self
.
fnl
[
i
],
batchable
=
True
)
# merge all the heads
last
=
torch
.
cat
(
[
self
.
g
.
pop_n_repr
(
'head%d'
%
hid
)
for
hid
in
range
(
self
.
num_heads
)],
dim
=
1
)
# output projection
self
.
g
.
set_n_repr
(
self
.
prp
[
-
1
](
last
))
self
.
g
.
update_all
(
gat_message
,
self
.
red
[
-
1
],
self
.
fnl
[
-
1
],
batchable
=
True
)
return
self
.
g
.
pop_n_repr
(
'head0'
)
def
main
(
args
):
# load and preprocess dataset
data
=
load_data
(
args
)
features
=
torch
.
FloatTensor
(
data
.
features
)
labels
=
torch
.
LongTensor
(
data
.
labels
)
mask
=
torch
.
ByteTensor
(
data
.
train_mask
)
in_feats
=
features
.
shape
[
1
]
n_classes
=
data
.
num_labels
n_edges
=
data
.
graph
.
number_of_edges
()
if
args
.
gpu
<
0
:
cuda
=
False
else
:
cuda
=
True
torch
.
cuda
.
set_device
(
args
.
gpu
)
features
=
features
.
cuda
()
labels
=
labels
.
cuda
()
mask
=
mask
.
cuda
()
# create GCN model
g
=
DGLGraph
(
data
.
graph
)
if
cuda
:
g
.
set_device
(
dgl
.
gpu
(
args
.
gpu
))
# create model
model
=
GAT
(
g
,
args
.
num_layers
,
in_feats
,
args
.
num_hidden
,
n_classes
,
args
.
num_heads
,
F
.
elu
,
args
.
in_drop
,
args
.
attn_drop
,
args
.
residual
)
if
cuda
:
model
.
cuda
()
# use optimizer
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
# initialize graph
dur
=
[]
for
epoch
in
range
(
args
.
epochs
):
if
epoch
>=
3
:
t0
=
time
.
time
()
# forward
logits
=
model
(
features
)
logp
=
F
.
log_softmax
(
logits
,
1
)
loss
=
F
.
nll_loss
(
logp
,
labels
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
if
epoch
>=
3
:
dur
.
append
(
time
.
time
()
-
t0
)
print
(
"Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}"
.
format
(
epoch
,
loss
.
item
(),
np
.
mean
(
dur
),
n_edges
/
np
.
mean
(
dur
)
/
1000
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'GAT'
)
register_data_args
(
parser
)
parser
.
add_argument
(
"--gpu"
,
type
=
int
,
default
=-
1
,
help
=
"Which GPU to use. Set -1 to use CPU."
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
20
,
help
=
"number of training epochs"
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
3
,
help
=
"number of attentional heads to use"
)
parser
.
add_argument
(
"--num-layers"
,
type
=
int
,
default
=
1
,
help
=
"number of hidden layers"
)
parser
.
add_argument
(
"--num-hidden"
,
type
=
int
,
default
=
8
,
help
=
"size of hidden units"
)
parser
.
add_argument
(
"--residual"
,
action
=
"store_false"
,
help
=
"use residual connection"
)
parser
.
add_argument
(
"--in-drop"
,
type
=
float
,
default
=
.
6
,
help
=
"input feature dropout"
)
parser
.
add_argument
(
"--attn-drop"
,
type
=
float
,
default
=
.
6
,
help
=
"attention dropout"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.005
,
help
=
"learning rate"
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
args
)
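In this batched variant, DGL groups destination nodes by in-degree, so GATReduce sees its messages as dense tensors of shape (B, deg, ·) and the softmax runs over dim=1, the neighbor axis. A quick shape check of that reduce step on random tensors (illustrative only, not part of the commit):

    # Illustrative only: the batched attention reduce used by gat_batch.py,
    # applied to one degree bucket of B nodes with deg incoming messages each.
    import torch
    import torch.nn.functional as F

    B, deg, D = 4, 3, 8
    a1 = torch.randn(B, 1, 1)    # per-destination score, broadcast over neighbors
    a2 = torch.randn(B, deg, 1)  # per-source score carried by the messages
    ft = torch.randn(B, deg, D)  # transformed source features
    e = F.softmax(F.leaky_relu(a1 + a2), dim=1)
    out = torch.sum(e * ft, dim=1)
    print(out.shape)  # torch.Size([4, 8])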
python/dgl/graph.py (modified)
@@ -572,8 +572,6 @@ class DGLGraph(DiGraph):
                           for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
             if len(msgs_batch) == 0:
                 msgs_reduced = None
-            elif len(msgs_batch) == 1:
-                msgs_reduced = msgs_batch[0]
             else:
                 msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
             # update phase

@@ -581,17 +579,48 @@ class DGLGraph(DiGraph):
            _set_repr(self.nodes[uu], ret)

The remainder of this hunk and the next one rewrite the batched receive path: _batch_recv now delegates message reduction to a new _batch_reduce helper and only applies the update function and writes the results back. After the change the two methods read:

    def _batch_recv(self, v, reduce_func, update_func):
        v_is_all = is_all(v)
        f_update = update_func
        reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
        if all_reduced_msgs is None:
            # no message; only do recv.
            if is_all(v):
                self.set_n_repr(f_update(self.get_n_repr(), None))
            else:
                self.set_n_repr(f_update(self.get_n_repr(v), None), v)
        else:
            # Read the node states in the degree-bucketing order.
            reordered_ns = self.get_n_repr(reordered_v)
            new_ns = f_update(reordered_ns, all_reduced_msgs)
            if v_is_all:
                # First do reorder and then replace the whole column.
                _, indices = F.sort(reordered_v)
                # TODO(minjie): manually convert ids to context.
                indices = F.to_context(indices, self.context)
                if isinstance(new_ns, dict):
                    for key, val in new_ns.items():
                        self._node_frame[key] = F.gather_row(val, indices)
                else:
                    self._node_frame[__REPR__] = F.gather_row(new_ns, indices)
            else:
                # Use setter to do reorder.
                self.set_n_repr(new_ns, reordered_v)

    def _batch_reduce(self, v, reduce_func):
        if is_all(v) and len(self._msg_frame) == 0:
            # no message has been sent
            return None, None
        if is_all(v):
            v = list(range(self.number_of_nodes()))
        # sanity checks
        v = utils.convert_to_id_tensor(v)
        f_reduce = _get_reduce_func(reduce_func)
        # degree bucketing
        degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
        reduced_msgs = []
        for deg, v_bkt in zip(degrees, v_buckets):
            if deg == 0:
                continue
            bkt_len = len(v_bkt)
            uu, vv = self.msg_graph.in_edges(v_bkt)
            in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
            ...

@@ -611,31 +640,22 @@ class DGLGraph(DiGraph):
            dst_reprs = self.get_n_repr(v_bkt)
            reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
        if len(reduced_msgs) == 0:
            # no message has been sent to the specified node
            return None, None
        # TODO: clear partial messages
        self.clear_messages()
        # Read the node states in the degree-bucketing order.
        reordered_v = F.pack(v_buckets)
        # Pack all reduced msgs together
        if isinstance(reduced_msgs[0], dict):
            all_reduced_msgs = {key: F.pack(val)
                                for key, val in reduced_msgs.items()}
        else:
            all_reduced_msgs = F.pack(reduced_msgs)
        return reordered_v, all_reduced_msgs

    def update_by_edge(self,
                       u, v,
    ...
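The batched receive path above hinges on scheduler.degree_bucketing, which groups the destination nodes by in-degree so that each bucket's incoming messages can be stacked into one dense tensor. A toy illustration of that grouping (not DGL's implementation; the function below is a simplified stand-in that works from a list of message destinations):

    # Illustrative only: group destination nodes by in-degree so each bucket
    # can be processed as one dense (num_nodes_in_bucket, deg, ...) batch.
    from collections import defaultdict

    def degree_bucketing(dst_of_each_message):
        in_deg = defaultdict(int)
        for d in dst_of_each_message:
            in_deg[d] += 1
        buckets = defaultdict(list)
        for node, deg in in_deg.items():
            buckets[deg].append(node)
        degrees = sorted(buckets)
        return degrees, [buckets[d] for d in degrees]

    print(degree_bucketing([1, 1, 2, 3, 3, 3]))  # ([1, 2, 3], [[2], [1], [3]])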
tests/test_basics2.py (modified)
@@ -8,7 +8,7 @@ def message_not_called(hu, e_uv):
     assert False
     return hu
 
-def reduce_not_called(msgs):
+def reduce_not_called(h, msgs):
     assert False
     return 0

@@ -70,18 +70,7 @@ def test_recv_no_pred():
     g.register_update_func(update_no_msg)
     g.recv(0)
 
-def test_skipped_reduce():
-    g = generate_graph()
-    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
-    g.register_message_func(message_func)
-    g.register_reduce_func(reduce_not_called)
-    g.register_update_func(update_func)
-    g.sendto(0, 1)
-    g.recv(1)
-    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
-
 if __name__ == '__main__':
     test_no_msg_update()
     test_double_recv()
     test_recv_no_pred()
-    test_skipped_reduce()