Commit a8a4fcba, authored Sep 24, 2018 by Minjie Wang (project: OpenDAS/dgl)
Parent: 882e2a7b

quickly integrating with tree-lstm example
Showing 6 changed files with 39 additions and 194 deletions:

- examples/pytorch/tree_lstm/train.py (+18, -20)
- examples/pytorch/tree_lstm/tree_lstm.py (+10, -24)
- include/dgl/graph_op.h (+6, -0)
- python/dgl/cached_graph.py (+0, -147, deleted)
- python/dgl/graph.py (+4, -2)
- src/graph/graph.cc (+1, -1)
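In brief: the SST example stops passing a `data.SSTBatch` named tuple around. The trees are merged with `dgl.batch` inside a custom `collate_fn`, word ids and labels ride on the batched `DGLGraph` as node fields (`'x'` and `'y'`, retrieved with `pop_n_repr`), the igraph-backed `CachedGraph` module is deleted in favor of the graph index's `adjacency_matrix()`, and a `GraphOp::LineGraph` declaration lands on the C++ side.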
examples/pytorch/tree_lstm/train.py

```diff
@@ -8,23 +8,17 @@ from torch.utils.data import DataLoader
 import dgl
 import dgl.data as data
+import dgl.ndarray as nd
 from tree_lstm import TreeLSTM
 
-def _batch_to_cuda(batch):
-    return data.SSTBatch(graph=batch.graph,
-                         nid_with_word=batch.nid_with_word.cuda(),
-                         wordid=batch.wordid.cuda(),
-                         label=batch.label.cuda())
-
-import dgl.context as ctx
-
 def tensor_topo_traverse(g, cuda, args):
     n = g.number_of_nodes()
     if cuda:
-        adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
+        adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
         mask = th.ones((n, 1)).cuda()
     else:
-        adjmat = g.cached_graph.adjmat().get(ctx.cpu())
+        adjmat = g._graph.adjacency_matrix().get(nd.cpu())
         mask = th.ones((n, 1))
     degree = th.spmm(adjmat, mask)
     while th.sum(mask) != 0.:
```
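For readers skimming the diff: `tensor_topo_traverse` realizes a topological (level-order) traversal with one sparse matmul per level instead of Python-side bookkeeping. Below is a minimal standalone sketch of that idea, assuming the row=dst/col=src adjacency layout the code above uses; `topo_frontiers` and the toy graph are made up for illustration, and the loop body here is my own arrangement since the original's is truncated in this view.

```python
import torch as th

def topo_frontiers(adj, n):
    """Yield node-id tensors level by level, zero in-degree nodes first."""
    mask = th.ones((n, 1))                       # 1 = node not yet emitted
    while th.sum(mask) != 0.:
        # remaining in-degree of each node, counting only unvisited sources
        degree = th.spmm(adj, mask)
        # frontier: unvisited nodes whose unvisited in-degree is zero
        frontier = (degree == 0.).float() * mask
        yield frontier.squeeze(1).nonzero().squeeze(1)
        mask = mask - frontier

# usage: a path 0 -> 1 -> 2, with edge (u, v) stored as adj[v, u] = 1
adj = th.tensor([[0., 0., 0.],
                 [1., 0., 0.],
                 [0., 1., 0.]]).to_sparse()
for level in topo_frontiers(adj, 3):
    print(level)   # tensor([0]), then tensor([1]), then tensor([2])
```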
```diff
@@ -39,10 +33,17 @@ def main(args):
     cuda = args.gpu >= 0
     if cuda:
         th.cuda.set_device(args.gpu)
+    def _batcher(trees):
+        bg = dgl.batch(trees)
+        if cuda:
+            reprs = bg.get_n_repr()
+            reprs = {key : reprs[key].cuda() for key in reprs}
+            bg.set_n_repr(reprs)
+        return bg
     trainset = data.SST()
     train_loader = DataLoader(dataset=trainset,
                               batch_size=args.batch_size,
-                              collate_fn=data.SST.batcher,
+                              collate_fn=_batcher,
                               shuffle=False,
                               num_workers=0)
     #testset = data.SST(mode='test')
```
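The collate function moves from the dataset class into `main` so it can close over the `cuda` flag and push node features to the GPU right after batching. A hedged sketch of the same pattern in isolation (`get_n_repr`/`set_n_repr` are the 2018-era DGL calls used in the diff above; `make_loader` is a made-up wrapper, not part of this commit):

```python
from torch.utils.data import DataLoader
import dgl

def make_loader(dataset, batch_size, cuda):
    def _batcher(trees):
        # merge the sampled trees into a single graph with many components;
        # per-node fields are concatenated along dim 0 by dgl.batch
        bg = dgl.batch(trees)
        if cuda:
            reprs = bg.get_n_repr()
            bg.set_n_repr({k: v.cuda() for k, v in reprs.items()})
        return bg
    return DataLoader(dataset=dataset, batch_size=batch_size,
                      collate_fn=_batcher, shuffle=False, num_workers=0)
```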
```diff
@@ -69,18 +70,15 @@ def main(args):
     dur = []
     for epoch in range(args.epochs):
         t_epoch = time.time()
-        for step, batch in enumerate(train_loader):
-            g = batch.graph
-            if cuda:
-                batch = _batch_to_cuda(batch)
+        for step, graph in enumerate(train_loader):
             if step >= 3:
                 t0 = time.time()
+            label = graph.pop_n_repr('y')
             # traverse graph
-            giter = list(tensor_topo_traverse(g, False, args))
-            logits = model(batch, zero_initializer, iterator=giter, train=True)
+            giter = list(tensor_topo_traverse(graph, False, args))
+            logits = model(graph, zero_initializer, iterator=giter, train=True)
             logp = F.log_softmax(logits, 1)
-            loss = F.nll_loss(logp, batch.label)
+            loss = F.nll_loss(logp, label)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
@@ -89,11 +87,11 @@ def main(args):
             if step > 0 and step % args.log_every == 0:
                 pred = th.argmax(logits, 1)
-                acc = th.sum(th.eq(batch.label, pred))
+                acc = th.sum(th.eq(label, pred))
                 mean_dur = np.mean(dur)
                 print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                       "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
-                    epoch, step, loss.item(), acc.item()/len(batch.label),
+                    epoch, step, loss.item(), acc.item()/args.batch_size,
                     mean_dur, args.batch_size/mean_dur))
         print("Epoch time(s):", time.time() - t_epoch)
```
examples/pytorch/tree_lstm/tree_lstm.py

```diff
@@ -10,23 +10,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 
-def topological_traverse(G):
-    import dgl
-    indegree_map = {v: d for v, d in G.in_degree() if d > 0}
-    # These nodes have zero indegree and ready to be returned.
-    zero_indegree = [v for v, d in G.in_degree() if d == 0]
-    while True:
-        yield zero_indegree
-        next_zero_indegree = []
-        while zero_indegree:
-            node = zero_indegree.pop()
-            for _, child in G.edges(node):
-                indegree_map[child] -= 1
-                if indegree_map[child] == 0:
-                    next_zero_indegree.append(child)
-                    del indegree_map[child]
-        if len(next_zero_indegree) == 0:
-            break
-        zero_indegree = next_zero_indegree
-
 class ChildSumTreeLSTMCell(nn.Module):
     def __init__(self, x_size, h_size):
```
```diff
@@ -83,13 +67,13 @@ class TreeLSTM(nn.Module):
         else:
             raise RuntimeError('Unknown cell type:', cell_type)
 
-    def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
+    def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
         """Compute tree-lstm prediction given a batch.
 
         Parameters
         ----------
-        batch : dgl.data.SSTBatch
-            The data batch.
+        graph : dgl.DGLGraph
+            The batched trees.
         zero_initializer : callable
             Function to return zero value tensor.
         h : Tensor, optional
@@ -104,15 +88,17 @@ class TreeLSTM(nn.Module):
         logits : Tensor
             The prediction of each node.
         """
-        g = batch.graph
+        g = graph
         n = g.number_of_nodes()
         g.register_message_func(self.cell.message_func, batchable=True)
         g.register_reduce_func(self.cell.reduce_func, batchable=True)
         g.register_apply_node_func(self.cell.apply_func, batchable=True)
         # feed embedding
-        embeds = self.embedding(batch.wordid)
-        x = zero_initializer((n, self.x_size))
-        x = x.index_copy(0, batch.nid_with_word, embeds)
+        wordid = g.pop_n_repr('x')
+        mask = (wordid != dgl.data.SST.PAD_WORD)
+        wordid = wordid * mask.long()
+        embeds = self.embedding(wordid)
+        x = embeds * th.unsqueeze(mask, 1).float()
         if h is None:
             h = zero_initializer((n, self.h_size))
         h_tild = zero_initializer((n, self.h_size))
```
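The new `forward` replaces the gather/`index_copy` scheme (embed only the nodes that carry words, scatter into a zero tensor) with a branch-free mask: clamp padded ids to a valid index, embed every node, then zero out the wordless rows. A standalone sketch of the trick, with a made-up vocabulary; `PAD_WORD` is assumed here to be -1, though any sentinel works once multiplied by the mask:

```python
import torch as th
import torch.nn as nn

PAD_WORD = -1                      # assumed sentinel for wordless nodes
embedding = nn.Embedding(10, 4)    # hypothetical vocab size 10, dim 4

wordid = th.tensor([3, PAD_WORD, 7])       # leaf, internal node, leaf
mask = (wordid != PAD_WORD)
safe_id = wordid * mask.long()             # sentinel -> 0, a valid row index
embeds = embedding(safe_id)                # embed every node
x = embeds * mask.unsqueeze(1).float()     # zero vectors for wordless nodes
```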
include/dgl/graph_op.h

```diff
@@ -8,6 +8,12 @@ namespace dgl {
 class GraphOp {
  public:
+  /*!
+   * \brief Return the line graph.
+   * \param graph The input graph.
+   * \return the line graph
+   */
+  static Graph LineGraph(const Graph* graph);
   /*!
    * \brief Return a disjoint union of the input graphs.
    *
```
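`LineGraph` is only declared here, so as a reminder of the construction (this is the textbook semantics, shown with networkx; the C++ implementation is assumed to match it for directed graphs): every edge of G becomes a node of L(G), and L(G) has an edge from (a, b) to (b, c) whenever one edge's destination is the other's source.

```python
import networkx as nx

G = nx.DiGraph([(0, 1), (1, 2), (1, 3)])
L = nx.line_graph(G)                   # nodes of L are edges of G
print(sorted(L.edges()))               # [((0, 1), (1, 2)), ((0, 1), (1, 3))]
```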
python/dgl/cached_graph.py (deleted, mode 100644 → 0)

```python
"""High-performance graph structure query component.

TODO: Currently implemented by igraph. Should replace with more efficient
solution later.
"""
from __future__ import absolute_import

import igraph

from . import backend as F
from .backend import Tensor
from . import utils

class CachedGraph:
    def __init__(self):
        self._graph = igraph.Graph(directed=True)
        self._freeze = False

    def add_nodes(self, num_nodes):
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        self._graph.add_vertices(num_nodes)

    def add_edge(self, u, v):
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        self._graph.add_edge(u, v)

    def add_edges(self, u, v):
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        # The edge will be assigned ids equal to the order.
        uvs = list(utils.edge_iter(u, v))
        self._graph.add_edges(uvs)

    def get_edge_id(self, u, v):
        uvs = list(utils.edge_iter(u, v))
        eids = self._graph.get_eids(uvs)
        return utils.toindex(eids)

    def in_edges(self, v):
        """Get in-edges of the vertices.

        Parameters
        ----------
        v : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no in-edges.
        """
        src = []
        dst = []
        orphan = []
        for vv in utils.node_iter(v):
            uu = self._graph.predecessors(vv)
            if len(uu) == 0:
                orphan.append(vv)
            else:
                src += uu
                dst += [vv] * len(uu)
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan

    def out_edges(self, u):
        """Get out-edges of the vertices.

        Parameters
        ----------
        u : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no out-edges.
        """
        src = []
        dst = []
        orphan = []
        for uu in utils.node_iter(u):
            vv = self._graph.successors(uu)
            if len(vv) == 0:
                orphan.append(uu)
            else:
                src += [uu] * len(vv)
                dst += vv
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan

    def in_degrees(self, v):
        degs = self._graph.indegree(list(v))
        return utils.toindex(degs)

    def num_edges(self):
        return self._graph.ecount()

    @utils.cached_member
    def edges(self):
        elist = self._graph.get_edgelist()
        src = [u for u, _ in elist]
        dst = [v for _, v in elist]
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        return src, dst

    @utils.cached_member
    def adjmat(self):
        """Return a sparse adjacency matrix.

        The row dimension represents the dst nodes; the column dimension
        represents the src nodes.
        """
        elist = self._graph.get_edgelist()
        src = F.tensor([u for u, _ in elist], dtype=F.int64)
        dst = F.tensor([v for _, v in elist], dtype=F.int64)
        src = F.unsqueeze(src, 0)
        dst = F.unsqueeze(dst, 0)
        idx = F.pack([dst, src])
        n = self._graph.vcount()
        dat = F.ones((len(elist),))
        mat = F.sparse_tensor(idx, dat, [n, n])
        return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))

    def freeze(self):
        self._freeze = True

def create_cached_graph(dglgraph):
    cg = CachedGraph()
    cg.add_nodes(dglgraph.number_of_nodes())
    cg._graph.add_edges(dglgraph.edge_list)
    cg.freeze()
    return cg
```
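The only piece of `CachedGraph` the tree-lstm example still needed was `adjmat()`, and `train.py` now gets that from `g._graph.adjacency_matrix()` instead. For reference, what `adjmat()` computed, sketched in plain torch (`adjacency` is a hypothetical helper that keeps the row=dst/col=src layout from the docstring above):

```python
import torch as th

def adjacency(edge_list, n):
    src = th.tensor([u for u, _ in edge_list], dtype=th.int64)
    dst = th.tensor([v for _, v in edge_list], dtype=th.int64)
    idx = th.stack([dst, src])                 # row = dst, col = src
    dat = th.ones(idx.shape[1])
    return th.sparse_coo_tensor(idx, dat, (n, n))

print(adjacency([(0, 1), (1, 2)], 3).to_dense())
```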
python/dgl/graph.py

```diff
@@ -48,6 +48,7 @@ class DGLGraph(object):
         # msg graph & frame
         self._msg_graph = create_graph_index()
         self._msg_frame = FrameRef()
+        self.reset_messages()
         # registered functions
         self._message_func = (None, None)
         self._reduce_func = (None, None)
@@ -112,7 +113,7 @@ class DGLGraph(object):
         self._msg_graph.clear()
         self._msg_frame.clear()
 
-    def clear_messages(self):
+    def reset_messages(self):
         """Clear all messages."""
         self._msg_graph.clear()
         self._msg_frame.clear()
@@ -447,6 +448,7 @@ class DGLGraph(object):
         self.clear()
         self._graph.from_networkx(nx_graph)
         self._msg_graph.add_nodes(self._graph.number_of_nodes())
+        # copy attributes
         def _batcher(lst):
             if isinstance(lst[0], Tensor):
                 return F.pack([F.unsqueeze(x, 0) for x in lst])
@@ -1078,7 +1080,7 @@ class DGLGraph(object):
         new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
         # TODO: clear partial messages
-        self.clear_messages()
+        self.reset_messages()
         # Pack all reducer results together
         reordered_v = F.pack(reordered_v)
```
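With the rename, `DGLGraph.__init__` and the message-passing path share one entry point for clearing message state: the constructor calls `reset_messages()` right after creating the message graph and frame, so a freshly built graph and a graph after a completed reduce end up in the same state.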
src/graph/graph.cc

```diff
@@ -20,7 +20,7 @@ void Graph::AddVertices(uint64_t num_vertices) {
 void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
   CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
   CHECK(HasVertex(src) && HasVertex(dst))
-    << "Invalid vertices: " << src << " " << dst;
+    << "Invalid vertices: src=" << src << " dst=" << dst;
   dgl_id_t eid = num_edges_++;
   adjlist_[src].succ.push_back(dst);
   adjlist_[src].edge_id.push_back(eid);
```