Commit 1eb17bb0 authored by Zihao Ye's avatar Zihao Ye Committed by Minjie Wang
Browse files

[Model] Adapt Tree-LSTM to new interface (#122)

* tree_lstm (new interface)

* simplify pop
parent 23191674
...@@ -15,10 +15,10 @@ from tree_lstm import TreeLSTM ...@@ -15,10 +15,10 @@ from tree_lstm import TreeLSTM
def tensor_topo_traverse(g, cuda, args): def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes() n = g.number_of_nodes()
if cuda: if cuda:
adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu)) adjmat = g._graph.adjacency_matrix().get(th.device('cuda:{}'.format(cuda)))
mask = th.ones((n, 1)).cuda() mask = th.ones((n, 1)).cuda()
else: else:
adjmat = g._graph.adjacency_matrix().get(nd.cpu()) adjmat = g._graph.adjacency_matrix().get(th.device('cpu'))
mask = th.ones((n, 1)) mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask) degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.: while th.sum(mask) != 0.:
...@@ -36,9 +36,8 @@ def main(args): ...@@ -36,9 +36,8 @@ def main(args):
def _batcher(trees): def _batcher(trees):
bg = dgl.batch(trees) bg = dgl.batch(trees)
if cuda: if cuda:
reprs = bg.get_n_repr() for key in bg.node_attr_schemes().keys():
reprs = {key : val.cuda() for key, val in reprs.items()} bg.ndata[key] = bg.ndata[key].cuda()
bg.set_n_repr(reprs)
return bg return bg
trainset = data.SST() trainset = data.SST()
train_loader = DataLoader(dataset=trainset, train_loader = DataLoader(dataset=trainset,
...@@ -73,7 +72,7 @@ def main(args): ...@@ -73,7 +72,7 @@ def main(args):
for step, graph in enumerate(train_loader): for step, graph in enumerate(train_loader):
if step >= 3: if step >= 3:
t0 = time.time() t0 = time.time()
label = graph.pop_n_repr('y') label = graph.ndata.pop('y')
# traverse graph # traverse graph
giter = list(tensor_topo_traverse(graph, False, args)) giter = list(tensor_topo_traverse(graph, False, args))
logits = model(graph, zero_initializer, iterator=giter, train=True) logits = model(graph, zero_initializer, iterator=giter, train=True)
......
...@@ -22,27 +22,27 @@ class ChildSumTreeLSTMCell(nn.Module): ...@@ -22,27 +22,27 @@ class ChildSumTreeLSTMCell(nn.Module):
self.rt = 0. self.rt = 0.
self.ut = 0. self.ut = 0.
def message_func(self, src, edge): def message_func(self, edges):
return {'h' : src['h'], 'c' : src['c']} return {'h' : edges.src['h'], 'c' : edges.src['c']}
def reduce_func(self, node, msgs): def reduce_func(self, nodes):
# equation (2) # equation (2)
h_tild = th.sum(msgs['h'], 1) h_tild = th.sum(nodes.mailbox['h'], 1)
# equation (4) # equation (4)
wx = self.W_f(node['x']).unsqueeze(1) # shape: (B, 1, H) wx = self.W_f(nodes.data['x']).unsqueeze(1) # shape: (B, 1, H)
uh = self.U_f(msgs['h']) # shape: (B, deg, H) uh = self.U_f(nodes.mailbox['h']) # shape: (B, deg, H)
f = th.sigmoid(wx + uh) # shape: (B, deg, H) f = th.sigmoid(wx + uh) # shape: (B, deg, H)
# equation (7) second term # equation (7) second term
c_tild = th.sum(f * msgs['c'], 1) c_tild = th.sum(f * nodes.mailbox['c'], 1)
return {'h_tild' : h_tild, 'c_tild' : c_tild} return {'h_tild' : h_tild, 'c_tild' : c_tild}
def apply_func(self, node): def apply_func(self, nodes):
# equation (3), (5), (6) # equation (3), (5), (6)
iou = self.W_iou(node['x']) + self.U_iou(node['h_tild']) iou = self.W_iou(nodes.data['x']) + self.U_iou(nodes.data['h_tild'])
i, o, u = th.chunk(iou, 3, 1) i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u) i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (7) # equation (7)
c = i * u + node['c_tild'] c = i * u + nodes.data['c_tild']
# equation (8) # equation (8)
h = o * th.tanh(c) h = o * th.tanh(c)
return {'h' : h, 'c' : c} return {'h' : h, 'c' : c}
...@@ -98,14 +98,15 @@ class TreeLSTM(nn.Module): ...@@ -98,14 +98,15 @@ class TreeLSTM(nn.Module):
mask = (wordid != dgl.data.SST.PAD_WORD) mask = (wordid != dgl.data.SST.PAD_WORD)
wordid = wordid * mask.long() wordid = wordid * mask.long()
embeds = self.embedding(wordid) embeds = self.embedding(wordid)
x = embeds * th.unsqueeze(mask, 1).float() g.ndata['x'] = embeds * th.unsqueeze(mask, 1).float()
if h is None: if h is None:
h = zero_initializer((n, self.h_size)) h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size)) g.ndata['h'] = h
g.ndata['h_tild'] = zero_initializer((n, self.h_size))
if c is None: if c is None:
c = zero_initializer((n, self.h_size)) c = zero_initializer((n, self.h_size))
c_tild = zero_initializer((n, self.h_size)) g.ndata['c'] = c
g.set_n_repr({'x' : x, 'h' : h, 'c' : c, 'h_tild' : h_tild, 'c_tild' : c_tild}) g.ndata['c_tild'] = zero_initializer((n, self.h_size))
# TODO(minjie): potential bottleneck # TODO(minjie): potential bottleneck
if iterator is None: if iterator is None:
g.propagate('topo') g.propagate('topo')
...@@ -113,7 +114,7 @@ class TreeLSTM(nn.Module): ...@@ -113,7 +114,7 @@ class TreeLSTM(nn.Module):
for frontier in iterator: for frontier in iterator:
g.pull(frontier) g.pull(frontier)
# compute logits # compute logits
h = g.pop_n_repr('h') h = g.ndata.pop('h')
h = self.dropout(h) h = self.dropout(h)
logits = self.linear(h) logits = self.linear(h)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment