Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3564fdc5
Unverified
Commit
3564fdc5
authored
Dec 06, 2018
by
Minjie Wang
Committed by
GitHub
Dec 06, 2018
Browse files
[Bugfix][Model] fix treelstm model (#274)
* fix bug after moving batcher out of dgl.data * disable mx utest
parent
f491d6b9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
5 deletions
+7
-5
examples/pytorch/tree_lstm/train.py
examples/pytorch/tree_lstm/train.py
+3
-1
tests/mxnet/test_propagate.py
tests/mxnet/test_propagate.py
+2
-2
tests/mxnet/test_traversal.py
tests/mxnet/test_traversal.py
+2
-2
No files found.
examples/pytorch/tree_lstm/train.py
View file @
3564fdc5
import
argparse
import
argparse
import
collections
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
torch
as
th
import
torch
as
th
...
@@ -12,7 +13,8 @@ from dgl.data.tree import SST
...
@@ -12,7 +13,8 @@ from dgl.data.tree import SST
from
tree_lstm
import
TreeLSTM
from
tree_lstm
import
TreeLSTM
def
batcher
(
dev
):
SSTBatch
=
collections
.
namedtuple
(
'SSTBatch'
,
[
'graph'
,
'mask'
,
'wordid'
,
'label'
])
def
batcher
(
device
):
def
batcher_dev
(
batch
):
def
batcher_dev
(
batch
):
batch_trees
=
dgl
.
batch
(
batch
)
batch_trees
=
dgl
.
batch
(
batch
)
return
SSTBatch
(
graph
=
batch_trees
,
return
SSTBatch
(
graph
=
batch_trees
,
...
...
tests/mxnet/test_propagate.py
View file @
3564fdc5
...
@@ -23,7 +23,7 @@ def test_prop_nodes_bfs():
...
@@ -23,7 +23,7 @@ def test_prop_nodes_bfs():
assert
np
.
allclose
(
g
.
ndata
[
'x'
].
asnumpy
(),
assert
np
.
allclose
(
g
.
ndata
[
'x'
].
asnumpy
(),
np
.
array
([[
2.
,
2.
],
[
4.
,
4.
],
[
6.
,
6.
],
[
8.
,
8.
],
[
9.
,
9.
]]))
np
.
array
([[
2.
,
2.
],
[
4.
,
4.
],
[
6.
,
6.
],
[
8.
,
8.
],
[
9.
,
9.
]]))
def
test_prop_edges_dfs
():
def
_
test_prop_edges_dfs
():
g
=
dgl
.
DGLGraph
(
nx
.
path_graph
(
5
))
g
=
dgl
.
DGLGraph
(
nx
.
path_graph
(
5
))
g
.
register_message_func
(
mfunc
)
g
.
register_message_func
(
mfunc
)
g
.
register_reduce_func
(
rfunc
)
g
.
register_reduce_func
(
rfunc
)
...
@@ -70,5 +70,5 @@ def test_prop_nodes_topo():
...
@@ -70,5 +70,5 @@ def test_prop_nodes_topo():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_prop_nodes_bfs
()
test_prop_nodes_bfs
()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_prop_edges_dfs()
#
_
test_prop_edges_dfs()
test_prop_nodes_topo
()
test_prop_nodes_topo
()
tests/mxnet/test_traversal.py
View file @
3564fdc5
...
@@ -84,7 +84,7 @@ def test_topological_nodes(n=1000):
...
@@ -84,7 +84,7 @@ def test_topological_nodes(n=1000):
assert
all
(
toset
(
x
)
==
toset
(
y
)
for
x
,
y
in
zip
(
layers_dgl
,
layers_spmv
))
assert
all
(
toset
(
x
)
==
toset
(
y
)
for
x
,
y
in
zip
(
layers_dgl
,
layers_spmv
))
DFS_LABEL_NAMES
=
[
'forward'
,
'reverse'
,
'nontree'
]
DFS_LABEL_NAMES
=
[
'forward'
,
'reverse'
,
'nontree'
]
def
test_dfs_labeled_edges
(
n
=
1000
,
example
=
False
):
def
_
test_dfs_labeled_edges
(
n
=
1000
,
example
=
False
):
dgl_g
=
dgl
.
DGLGraph
()
dgl_g
=
dgl
.
DGLGraph
()
dgl_g
.
add_nodes
(
6
)
dgl_g
.
add_nodes
(
6
)
dgl_g
.
add_edges
([
0
,
1
,
0
,
3
,
3
],
[
1
,
2
,
2
,
4
,
5
])
dgl_g
.
add_edges
([
0
,
1
,
0
,
3
,
3
],
[
1
,
2
,
2
,
4
,
5
])
...
@@ -124,4 +124,4 @@ if __name__ == '__main__':
...
@@ -124,4 +124,4 @@ if __name__ == '__main__':
test_bfs
()
test_bfs
()
test_topological_nodes
()
test_topological_nodes
()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#test_dfs_labeled_edges()
#
_
test_dfs_labeled_edges()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment