Commit 1eb17bb0 authored by Zihao Ye, committed by Minjie Wang
Browse files

[Model] Adapt Tree-LSTM to new interface (#122)

* tree_lstm (new interface)

* simplify pop
parent 23191674
......@@ -15,10 +15,10 @@ from tree_lstm import TreeLSTM
def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes()
if cuda:
adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
adjmat = g._graph.adjacency_matrix().get(th.device('cuda:{}'.format(cuda)))
mask = th.ones((n, 1)).cuda()
else:
adjmat = g._graph.adjacency_matrix().get(nd.cpu())
adjmat = g._graph.adjacency_matrix().get(th.device('cpu'))
mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.:
......@@ -36,9 +36,8 @@ def main(args):
def _batcher(trees):
bg = dgl.batch(trees)
if cuda:
reprs = bg.get_n_repr()
reprs = {key : val.cuda() for key, val in reprs.items()}
bg.set_n_repr(reprs)
for key in bg.node_attr_schemes().keys():
bg.ndata[key] = bg.ndata[key].cuda()
return bg
trainset = data.SST()
train_loader = DataLoader(dataset=trainset,
......@@ -73,7 +72,7 @@ def main(args):
for step, graph in enumerate(train_loader):
if step >= 3:
t0 = time.time()
label = graph.pop_n_repr('y')
label = graph.ndata.pop('y')
# traverse graph
giter = list(tensor_topo_traverse(graph, False, args))
logits = model(graph, zero_initializer, iterator=giter, train=True)
......
......@@ -22,27 +22,27 @@ class ChildSumTreeLSTMCell(nn.Module):
self.rt = 0.
self.ut = 0.
def message_func(self, src, edge):
return {'h' : src['h'], 'c' : src['c']}
def message_func(self, edges):
return {'h' : edges.src['h'], 'c' : edges.src['c']}
def reduce_func(self, node, msgs):
def reduce_func(self, nodes):
# equation (2)
h_tild = th.sum(msgs['h'], 1)
h_tild = th.sum(nodes.mailbox['h'], 1)
# equation (4)
wx = self.W_f(node['x']).unsqueeze(1) # shape: (B, 1, H)
uh = self.U_f(msgs['h']) # shape: (B, deg, H)
wx = self.W_f(nodes.data['x']).unsqueeze(1) # shape: (B, 1, H)
uh = self.U_f(nodes.mailbox['h']) # shape: (B, deg, H)
f = th.sigmoid(wx + uh) # shape: (B, deg, H)
# equation (7) second term
c_tild = th.sum(f * msgs['c'], 1)
c_tild = th.sum(f * nodes.mailbox['c'], 1)
return {'h_tild' : h_tild, 'c_tild' : c_tild}
def apply_func(self, node):
def apply_func(self, nodes):
# equation (3), (5), (6)
iou = self.W_iou(node['x']) + self.U_iou(node['h_tild'])
iou = self.W_iou(nodes.data['x']) + self.U_iou(nodes.data['h_tild'])
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (7)
c = i * u + node['c_tild']
c = i * u + nodes.data['c_tild']
# equation (8)
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
......@@ -98,14 +98,15 @@ class TreeLSTM(nn.Module):
mask = (wordid != dgl.data.SST.PAD_WORD)
wordid = wordid * mask.long()
embeds = self.embedding(wordid)
x = embeds * th.unsqueeze(mask, 1).float()
g.ndata['x'] = embeds * th.unsqueeze(mask, 1).float()
if h is None:
h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size))
g.ndata['h'] = h
g.ndata['h_tild'] = zero_initializer((n, self.h_size))
if c is None:
c = zero_initializer((n, self.h_size))
c_tild = zero_initializer((n, self.h_size))
g.set_n_repr({'x' : x, 'h' : h, 'c' : c, 'h_tild' : h_tild, 'c_tild' : c_tild})
g.ndata['c'] = c
g.ndata['c_tild'] = zero_initializer((n, self.h_size))
# TODO(minjie): potential bottleneck
if iterator is None:
g.propagate('topo')
......@@ -113,7 +114,7 @@ class TreeLSTM(nn.Module):
for frontier in iterator:
g.pull(frontier)
# compute logits
h = g.pop_n_repr('h')
h = g.ndata.pop('h')
h = self.dropout(h)
logits = self.linear(h)
return logits
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment