Unverified Commit 4edde000 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix tree lstm (#2052)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-63-129.ec2.internal>
parent c8a44c78
...@@ -27,3 +27,6 @@ DGLBACKEND=mxnet python3 train.py --gpu 0 ...@@ -27,3 +27,6 @@ DGLBACKEND=mxnet python3 train.py --gpu 0
## Speed Test ## Speed Test
See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc . See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .
## Note
The code can work with MXNet 1.5.1
...@@ -96,6 +96,7 @@ class TreeLSTM(gluon.nn.Block): ...@@ -96,6 +96,7 @@ class TreeLSTM(gluon.nn.Block):
self.linear = gluon.nn.Dense(num_classes) self.linear = gluon.nn.Dense(num_classes)
cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
self.cell = cell(x_size, h_size) self.cell = cell(x_size, h_size)
self.ctx = ctx
def forward(self, batch, h, c): def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch. """Compute tree-lstm prediction given a batch.
...@@ -113,6 +114,7 @@ class TreeLSTM(gluon.nn.Block): ...@@ -113,6 +114,7 @@ class TreeLSTM(gluon.nn.Block):
The prediction of each node. The prediction of each node.
""" """
g = batch.graph g = batch.graph
g = g.to(self.ctx)
# feed embedding # feed embedding
embeds = self.embedding(batch.wordid * batch.mask) embeds = self.embedding(batch.wordid * batch.mask)
wiou = self.cell.W_iou(self.dropout(embeds)) wiou = self.cell.W_iou(self.dropout(embeds))
......
...@@ -214,6 +214,7 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -214,6 +214,7 @@ class SSTDataset(DGLBuiltinDataset):
self._trees = load_graphs(graph_path)[0] self._trees = load_graphs(graph_path)[0]
self._vocab = load_info(vocab_path)['vocab'] self._vocab = load_info(vocab_path)['vocab']
self._pretrained_emb = None
if os.path.exists(emb_path): if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)['embed'] self._pretrained_emb = load_info(emb_path)['embed']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment