Unverified Commit 4edde000 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix tree lstm (#2052)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-63-129.ec2.internal>
parent c8a44c78
......@@ -27,3 +27,6 @@ DGLBACKEND=mxnet python3 train.py --gpu 0
## Speed Test
See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .
## Note
The code can work with MXNet 1.5.1
......@@ -96,6 +96,7 @@ class TreeLSTM(gluon.nn.Block):
self.linear = gluon.nn.Dense(num_classes)
cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
self.cell = cell(x_size, h_size)
self.ctx = ctx
def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch.
......@@ -113,6 +114,7 @@ class TreeLSTM(gluon.nn.Block):
The prediction of each node.
"""
g = batch.graph
g = g.to(self.ctx)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
wiou = self.cell.W_iou(self.dropout(embeds))
......
......@@ -214,6 +214,7 @@ class SSTDataset(DGLBuiltinDataset):
self._trees = load_graphs(graph_path)[0]
self._vocab = load_info(vocab_path)['vocab']
self._pretrained_emb = None
if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)['embed']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment