tree_lstm.py 4.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import time
import itertools
import networkx as nx
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl

class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
    """Per-node Tree-LSTM update: turn ``iou`` pre-activations into ``h``/``c``."""

    def hybrid_forward(self, F, iou, b_iou, c):
        """Apply the bias and gate non-linearities and update the states.

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon; only ``broadcast_add`` is
            used from it.
        iou : Tensor
            Pre-activations that are split into three equal parts along
            axis 1 (the ``i``, ``o`` and ``u`` parts, in that order).
        b_iou : Tensor
            Bias that is broadcast-added to ``iou`` before the split.
        c : Tensor
            Cell state accumulated so far for each node.

        Returns
        -------
        h : Tensor
            ``sigmoid(o) * tanh(c_new)``.
        c : Tensor
            ``sigmoid(i) * tanh(u) + c``.
        """
        biased = F.broadcast_add(iou, b_iou)
        in_part, out_part, update_part = biased.split(num_outputs=3, axis=1)

        in_gate = in_part.sigmoid()
        out_gate = out_part.sigmoid()
        candidate = update_part.tanh()

        new_c = in_gate * candidate + c
        new_h = out_gate * new_c.tanh()
        return new_h, new_c

class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
    """Combine the child states collected in a node's mailbox.

    Parameters
    ----------
    U_iou : callable block
        Applied to the flattened child hidden states to produce ``iou``.
    U_f : callable block
        Applied to the flattened child hidden states to produce the forget
        gate pre-activations; its output is reshaped like ``h``.
    """

    def __init__(self, U_iou, U_f):
        super(_TreeLSTMCellReduceFunc, self).__init__()
        self.U_iou = U_iou
        self.U_f = U_f

    def hybrid_forward(self, F, h, c):
        """Reduce child hidden/cell states into ``iou`` and a summed cell state.

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon (unused here).
        h : Tensor
            Child hidden states; every axis after the first is flattened
            before being fed to ``U_f`` and ``U_iou``.
        c : Tensor
            Child cell states, multiplied element-wise with the forget gates
            and summed over axis 1.

        Returns
        -------
        iou : Tensor
            ``U_iou`` applied to the flattened child hidden states.
        c : Tensor
            Forget-gated child cell states summed over axis 1.
        """
        # Keep axis 0 as is and merge all remaining axes into one.
        flat_h = h.reshape((0, -1))

        forget = self.U_f(flat_h).sigmoid()
        forget = forget.reshape_like(h)
        gated_c = (forget * c).sum(axis=1)

        return self.U_iou(flat_h), gated_c

class _TreeLSTMCell(gluon.HybridBlock):
    """Shared base for the Tree-LSTM cells.

    Owns the ``iou`` bias parameter and provides the message and apply-node
    callbacks; subclasses supply ``reduce_func`` and ``W_iou``.

    Parameters
    ----------
    h_size : int
        Hidden size; the bias parameter has shape ``(1, 3 * h_size)``.
    """

    def __init__(self, h_size):
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        bias_shape = (1, 3 * h_size)
        self.b_iou = self.params.get('bias', shape=bias_shape, init='zeros')

    def message_func(self, edges):
        """Send the source node's ``h`` and ``c`` along each edge."""
        src = edges.src
        return {'h': src['h'], 'c': src['c']}

    def apply_node_func(self, nodes):
        """Compute the new ``h`` and ``c`` of ``nodes`` from their ``iou`` and ``c``."""
        iou = nodes.data['iou']
        # Fetch the copy of the bias living on the same context as the data.
        bias = self.b_iou.data(iou.context)
        cell = nodes.data['c']
        new_h, new_c = self._apply_node_func(iou, bias, cell)
        return {'h': new_h, 'c': new_c}

class TreeLSTMCell(_TreeLSTMCell):
    """Tree-LSTM cell whose reduce step works on concatenated child states.

    Parameters
    ----------
    x_size : int
        Accepted for a uniform cell constructor signature; not used here.
    h_size : int
        Hidden size. The forget-gate layer emits ``2 * h_size`` units and
        the ``iou`` layers emit ``3 * h_size`` units.
    """

    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__(h_size)
        # Layers are created in the same order as before: U_iou, U_f, W_iou.
        u_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        u_f = gluon.nn.Dense(2 * h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(u_iou, u_f)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)

    def reduce_func(self, nodes):
        """Reduce the mailbox ``h``/``c`` messages into ``iou`` and ``c``."""
        mailbox = nodes.mailbox
        iou, cell = self._reduce_func(mailbox['h'], mailbox['c'])
        return {'iou': iou, 'c': cell}

class ChildSumTreeLSTMCell(_TreeLSTMCell):
    """Tree-LSTM cell that sums over children instead of concatenating them.

    Parameters
    ----------
    x_size : int
        Accepted for a uniform cell constructor signature; not used here.
    h_size : int
        Hidden size. The forget-gate layer emits ``h_size`` units and the
        ``iou`` layers emit ``3 * h_size`` units.
    """

    def __init__(self, x_size, h_size):
        # The base class needs h_size to size the iou bias; calling it
        # without arguments raised TypeError as soon as a cell was built.
        super(ChildSumTreeLSTMCell, self).__init__(h_size)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        # flatten=False applies the layer to the last axis only, so one
        # forget gate is produced per child. The default (flatten=True)
        # would merge the children axis into the features, making the gate
        # no longer line up with the per-child cell states below.
        self.U_f = gluon.nn.Dense(h_size, flatten=False)

    def reduce_func(self, nodes):
        """Reduce the mailbox ``h``/``c`` messages into ``iou`` and ``c``.

        The child hidden states are summed over axis 1 and fed to ``U_iou``;
        each child's cell state is scaled by its own forget gate and the
        results are summed over axis 1.
        """
        # NOTE(review): mailbox tensors are assumed to be laid out as
        # (nodes, children, h_size) -- confirm against the DGL version in use.
        child_h = nodes.mailbox['h']
        child_c = nodes.mailbox['c']

        h_tild = child_h.sum(axis=1)
        f = self.U_f(child_h).sigmoid()
        c = (f * child_c).sum(axis=1)
        return {'iou': self.U_iou(h_tild), 'c': c}

class TreeLSTM(gluon.nn.Block):
    """Tree-LSTM model: embedding, Tree-LSTM cell, dropout and a linear head.

    Parameters
    ----------
    num_vocabs : int
        Number of rows of the embedding table.
    x_size : int
        Embedding dimension.
    h_size : int
        Hidden size passed to the Tree-LSTM cell.
    num_classes : int
        Number of output units of the final linear layer.
    dropout : float
        Rate passed to ``gluon.nn.Dropout``.
    cell_type : str
        ``'nary'`` selects ``TreeLSTMCell``; any other value selects
        ``ChildSumTreeLSTMCell``.
    pretrained_emb : Tensor, optional
        When given, the embedding is initialized immediately and its weight
        is set to this value.
    ctx : Context, optional
        Context used to initialize the embedding when ``pretrained_emb`` is
        given.
    """

    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 cell_type='nary',
                 pretrained_emb=None,
                 ctx=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            # The weight must exist before set_data can overwrite it.
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)
        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)
        cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        # feed embedding
        # NOTE(review): the mask presumably marks nodes that carry a word;
        # masked-out nodes look up index 0 here and have their iou zeroed
        # below -- confirm against the batch construction.
        embeds = self.embedding(batch.wordid * batch.mask)
        wiou = self.cell.W_iou(self.dropout(embeds))
        g.ndata['iou'] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits