Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1eb17bb0
Commit
1eb17bb0
authored
Nov 08, 2018
by
Zihao Ye
Committed by
Minjie Wang
Nov 08, 2018
Browse files
[Model] Adapt Tree-LSTM to new interface (#122)
* tree_lstm (new interface) * simplify pop
parent
23191674
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
21 deletions
+21
-21
examples/pytorch/tree_lstm/train.py
examples/pytorch/tree_lstm/train.py
+5
-6
examples/pytorch/tree_lstm/tree_lstm.py
examples/pytorch/tree_lstm/tree_lstm.py
+16
-15
No files found.
examples/pytorch/tree_lstm/train.py
View file @
1eb17bb0
...
@@ -15,10 +15,10 @@ from tree_lstm import TreeLSTM
...
@@ -15,10 +15,10 @@ from tree_lstm import TreeLSTM
def
tensor_topo_traverse
(
g
,
cuda
,
args
):
def
tensor_topo_traverse
(
g
,
cuda
,
args
):
n
=
g
.
number_of_nodes
()
n
=
g
.
number_of_nodes
()
if
cuda
:
if
cuda
:
adjmat
=
g
.
_graph
.
adjacency_matrix
().
get
(
nd
.
gpu
(
args
.
gpu
))
adjmat
=
g
.
_graph
.
adjacency_matrix
().
get
(
th
.
device
(
'cuda:{}'
.
format
(
cuda
)
))
mask
=
th
.
ones
((
n
,
1
)).
cuda
()
mask
=
th
.
ones
((
n
,
1
)).
cuda
()
else
:
else
:
adjmat
=
g
.
_graph
.
adjacency_matrix
().
get
(
nd
.
cpu
(
))
adjmat
=
g
.
_graph
.
adjacency_matrix
().
get
(
th
.
device
(
'
cpu
'
))
mask
=
th
.
ones
((
n
,
1
))
mask
=
th
.
ones
((
n
,
1
))
degree
=
th
.
spmm
(
adjmat
,
mask
)
degree
=
th
.
spmm
(
adjmat
,
mask
)
while
th
.
sum
(
mask
)
!=
0.
:
while
th
.
sum
(
mask
)
!=
0.
:
...
@@ -36,9 +36,8 @@ def main(args):
...
@@ -36,9 +36,8 @@ def main(args):
def
_batcher
(
trees
):
def
_batcher
(
trees
):
bg
=
dgl
.
batch
(
trees
)
bg
=
dgl
.
batch
(
trees
)
if
cuda
:
if
cuda
:
reprs
=
bg
.
get_n_repr
()
for
key
in
bg
.
node_attr_schemes
().
keys
():
reprs
=
{
key
:
val
.
cuda
()
for
key
,
val
in
reprs
.
items
()}
bg
.
ndata
[
key
]
=
bg
.
ndata
[
key
].
cuda
()
bg
.
set_n_repr
(
reprs
)
return
bg
return
bg
trainset
=
data
.
SST
()
trainset
=
data
.
SST
()
train_loader
=
DataLoader
(
dataset
=
trainset
,
train_loader
=
DataLoader
(
dataset
=
trainset
,
...
@@ -73,7 +72,7 @@ def main(args):
...
@@ -73,7 +72,7 @@ def main(args):
for
step
,
graph
in
enumerate
(
train_loader
):
for
step
,
graph
in
enumerate
(
train_loader
):
if
step
>=
3
:
if
step
>=
3
:
t0
=
time
.
time
()
t0
=
time
.
time
()
label
=
graph
.
pop_n_repr
(
'y'
)
label
=
graph
.
ndata
.
pop
(
'y'
)
# traverse graph
# traverse graph
giter
=
list
(
tensor_topo_traverse
(
graph
,
False
,
args
))
giter
=
list
(
tensor_topo_traverse
(
graph
,
False
,
args
))
logits
=
model
(
graph
,
zero_initializer
,
iterator
=
giter
,
train
=
True
)
logits
=
model
(
graph
,
zero_initializer
,
iterator
=
giter
,
train
=
True
)
...
...
examples/pytorch/tree_lstm/tree_lstm.py
View file @
1eb17bb0
...
@@ -22,27 +22,27 @@ class ChildSumTreeLSTMCell(nn.Module):
...
@@ -22,27 +22,27 @@ class ChildSumTreeLSTMCell(nn.Module):
self
.
rt
=
0.
self
.
rt
=
0.
self
.
ut
=
0.
self
.
ut
=
0.
def
message_func
(
self
,
src
,
edge
):
def
message_func
(
self
,
edge
s
):
return
{
'h'
:
src
[
'h'
],
'c'
:
src
[
'c'
]}
return
{
'h'
:
edges
.
src
[
'h'
],
'c'
:
edges
.
src
[
'c'
]}
def
reduce_func
(
self
,
node
,
msg
s
):
def
reduce_func
(
self
,
nodes
):
# equation (2)
# equation (2)
h_tild
=
th
.
sum
(
msgs
[
'h'
],
1
)
h_tild
=
th
.
sum
(
nodes
.
mailbox
[
'h'
],
1
)
# equation (4)
# equation (4)
wx
=
self
.
W_f
(
node
[
'x'
]).
unsqueeze
(
1
)
# shape: (B, 1, H)
wx
=
self
.
W_f
(
node
s
.
data
[
'x'
]).
unsqueeze
(
1
)
# shape: (B, 1, H)
uh
=
self
.
U_f
(
msgs
[
'h'
])
# shape: (B, deg, H)
uh
=
self
.
U_f
(
nodes
.
mailbox
[
'h'
])
# shape: (B, deg, H)
f
=
th
.
sigmoid
(
wx
+
uh
)
# shape: (B, deg, H)
f
=
th
.
sigmoid
(
wx
+
uh
)
# shape: (B, deg, H)
# equation (7) second term
# equation (7) second term
c_tild
=
th
.
sum
(
f
*
msgs
[
'c'
],
1
)
c_tild
=
th
.
sum
(
f
*
nodes
.
mailbox
[
'c'
],
1
)
return
{
'h_tild'
:
h_tild
,
'c_tild'
:
c_tild
}
return
{
'h_tild'
:
h_tild
,
'c_tild'
:
c_tild
}
def
apply_func
(
self
,
node
):
def
apply_func
(
self
,
node
s
):
# equation (3), (5), (6)
# equation (3), (5), (6)
iou
=
self
.
W_iou
(
node
[
'x'
])
+
self
.
U_iou
(
node
[
'h_tild'
])
iou
=
self
.
W_iou
(
node
s
.
data
[
'x'
])
+
self
.
U_iou
(
node
s
.
data
[
'h_tild'
])
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
# equation (7)
# equation (7)
c
=
i
*
u
+
node
[
'c_tild'
]
c
=
i
*
u
+
node
s
.
data
[
'c_tild'
]
# equation (8)
# equation (8)
h
=
o
*
th
.
tanh
(
c
)
h
=
o
*
th
.
tanh
(
c
)
return
{
'h'
:
h
,
'c'
:
c
}
return
{
'h'
:
h
,
'c'
:
c
}
...
@@ -98,14 +98,15 @@ class TreeLSTM(nn.Module):
...
@@ -98,14 +98,15 @@ class TreeLSTM(nn.Module):
mask
=
(
wordid
!=
dgl
.
data
.
SST
.
PAD_WORD
)
mask
=
(
wordid
!=
dgl
.
data
.
SST
.
PAD_WORD
)
wordid
=
wordid
*
mask
.
long
()
wordid
=
wordid
*
mask
.
long
()
embeds
=
self
.
embedding
(
wordid
)
embeds
=
self
.
embedding
(
wordid
)
x
=
embeds
*
th
.
unsqueeze
(
mask
,
1
).
float
()
g
.
ndata
[
'x'
]
=
embeds
*
th
.
unsqueeze
(
mask
,
1
).
float
()
if
h
is
None
:
if
h
is
None
:
h
=
zero_initializer
((
n
,
self
.
h_size
))
h
=
zero_initializer
((
n
,
self
.
h_size
))
h_tild
=
zero_initializer
((
n
,
self
.
h_size
))
g
.
ndata
[
'h'
]
=
h
g
.
ndata
[
'h_tild'
]
=
zero_initializer
((
n
,
self
.
h_size
))
if
c
is
None
:
if
c
is
None
:
c
=
zero_initializer
((
n
,
self
.
h_size
))
c
=
zero_initializer
((
n
,
self
.
h_size
))
c_tild
=
zero_initializer
((
n
,
self
.
h_size
))
g
.
ndata
[
'c'
]
=
c
g
.
set_n_repr
({
'x'
:
x
,
'h'
:
h
,
'c'
:
c
,
'h_tild'
:
h_tild
,
'c_tild'
:
c_tild
}
)
g
.
ndata
[
'c_tild'
]
=
zero_initializer
((
n
,
self
.
h_size
)
)
# TODO(minjie): potential bottleneck
# TODO(minjie): potential bottleneck
if
iterator
is
None
:
if
iterator
is
None
:
g
.
propagate
(
'topo'
)
g
.
propagate
(
'topo'
)
...
@@ -113,7 +114,7 @@ class TreeLSTM(nn.Module):
...
@@ -113,7 +114,7 @@ class TreeLSTM(nn.Module):
for
frontier
in
iterator
:
for
frontier
in
iterator
:
g
.
pull
(
frontier
)
g
.
pull
(
frontier
)
# compute logits
# compute logits
h
=
g
.
pop_n_repr
(
'h'
)
h
=
g
.
ndata
.
pop
(
'h'
)
h
=
self
.
dropout
(
h
)
h
=
self
.
dropout
(
h
)
logits
=
self
.
linear
(
h
)
logits
=
self
.
linear
(
h
)
return
logits
return
logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment