# tree_lstm.py
"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
6
7
import time

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
8
9
import dgl

10
import mxnet as mx
11
12
13
import networkx as nx
import numpy as np
from mxnet import gluon
14
15


16
17
18
19
20
21
22
23
24
25
class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
    """Node-update step of a Tree-LSTM cell (hybridizable).

    Takes the pre-activation gate tensor ``iou``, adds the bias ``b_iou``
    (broadcast), splits it into input / output / update parts along axis 1
    and produces the new hidden and cell states.
    """

    def hybrid_forward(self, F, iou, b_iou, c):
        """Return ``(h, c)`` for the nodes being updated.

        Parameters
        ----------
        iou : Tensor
            Gate pre-activations; axis 1 is split into three equal parts
            (input, output, update) -- so its width must be divisible by 3.
        b_iou : Tensor
            Bias broadcast-added to ``iou``.
        c : Tensor
            Cell contribution accumulated from the children, added to the
            gated update.
        """
        gates = F.broadcast_add(iou, b_iou)
        in_gate, out_gate, update = gates.split(num_outputs=3, axis=1)
        # c_new = sigmoid(i) * tanh(u) + c ; h_new = sigmoid(o) * tanh(c_new)
        new_c = in_gate.sigmoid() * update.tanh() + c
        new_h = out_gate.sigmoid() * new_c.tanh()
        return new_h, new_c

26

27
28
29
30
31
32
33
34
35
36
37
38
39
class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
    """Reduce step of the N-ary Tree-LSTM cell (hybridizable).

    Parameters
    ----------
    U_iou : Block
        Projection applied to the concatenated children's hidden states to
        produce the gate pre-activations.
    U_f : Block
        Projection applied to the same concatenation to produce the forget
        gates; its output is reshaped back to the shape of ``h``, so it must
        emit exactly as many values per node as ``h`` holds per node.
    """

    def __init__(self, U_iou, U_f):
        super(_TreeLSTMCellReduceFunc, self).__init__()
        # Assigning Blocks as attributes registers them as Gluon children.
        self.U_iou = U_iou
        self.U_f = U_f

    def hybrid_forward(self, F, h, c):
        """Return ``(iou, c)`` aggregated from the children in ``h``/``c``.

        ``h`` is reshaped with ``(0, -1)`` (MXNet: keep axis 0, flatten the
        rest), i.e. children's hidden states are concatenated per node.
        Presumably ``h``/``c`` are (nodes, children, h_size) mailbox tensors;
        confirm against the caller.
        """
        flat_h = h.reshape((0, -1))
        # One forget gate per child element, reshaped to line up with c.
        forget = self.U_f(flat_h).sigmoid()
        forget = forget.reshape_like(h)
        gated_c = (forget * c).sum(axis=1)
        child_iou = self.U_iou(flat_h)
        return child_iou, gated_c

40

41
42
43
44
class _TreeLSTMCell(gluon.HybridBlock):
    """Shared base for the Tree-LSTM cells.

    Provides the DGL ``message_func`` and ``apply_node_func`` callbacks;
    subclasses supply ``reduce_func`` and the ``W_iou`` input projection.

    Parameters
    ----------
    h_size : int
        Hidden size; the gate bias ``b_iou`` has shape ``(1, 3 * h_size)``.
    """

    def __init__(self, h_size):
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        bias_shape = (1, 3 * h_size)
        self.b_iou = self.params.get("bias", shape=bias_shape, init="zeros")

    def message_func(self, edges):
        """Send each source node's ``h`` and ``c`` along its edges."""
        src = edges.src
        return {"h": src["h"], "c": src["c"]}

    def apply_node_func(self, nodes):
        """Compute new ``h``/``c`` from the node's ``iou`` and ``c`` data."""
        node_data = nodes.data
        iou = node_data["iou"]
        # Fetch the bias copy living on the same device as the node data.
        bias = self.b_iou.data(iou.context)
        new_h, new_c = self._apply_node_func(iou, bias, node_data["c"])
        return {"h": new_h, "c": new_c}

58
59
60
61
62

class TreeLSTMCell(_TreeLSTMCell):
    """N-ary Tree-LSTM cell with a forget projection sized for two children.

    The forget projection outputs ``2 * h_size`` values per node, which the
    reduce step reshapes to the shape of the children's stacked ``h`` -- so
    every internal node must receive exactly two messages (presumably the
    binarized SST trees; confirm).

    Parameters
    ----------
    x_size : int
        Input embedding size. Unused here: Gluon ``Dense`` infers its input
        width on first use.
    h_size : int
        Hidden size.
    """

    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__(h_size)
        # Creation order (U_iou, U_f, reducer, W_iou) matches the original so
        # auto-generated parameter names are unchanged.
        children_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        children_forget = gluon.nn.Dense(2 * h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(children_iou, children_forget)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)

    def reduce_func(self, nodes):
        """Aggregate the mailbox ``h``/``c`` into the node's ``iou`` and ``c``."""
        mailbox = nodes.mailbox
        iou, c = self._reduce_func(mailbox["h"], mailbox["c"])
        return {"iou": iou, "c": c}

73
74
75
76
77
78
79
80
81

class ChildSumTreeLSTMCell(_TreeLSTMCell):
    """Child-Sum Tree-LSTM cell (supports any number of children).

    Parameters
    ----------
    x_size : int
        Input embedding size. Unused here: Gluon ``Dense`` infers its input
        width on first use.
    h_size : int
        Hidden size.
    """

    def __init__(self, x_size, h_size):
        # The base class needs h_size to create the (1, 3 * h_size) gate
        # bias used by the inherited apply_node_func; omitting it raised a
        # TypeError.
        super(ChildSumTreeLSTMCell, self).__init__(h_size)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        # flatten=False applies U_f to the last axis only, giving one forget
        # gate per child. With the default flatten=True the children would be
        # flattened together (weight shape tied to the first in-degree seen)
        # and the result would not line up with the per-child c below.
        self.U_f = gluon.nn.Dense(h_size, flatten=False)

    def reduce_func(self, nodes):
        """Aggregate children's messages into the node's ``iou`` and ``c``.

        ``iou`` is ``U_iou`` applied to the sum of the children's ``h``
        (summed over axis 1); ``c`` is the sum over children of each child's
        ``c`` gated by its own forget gate.
        """
        child_h = nodes.mailbox["h"]
        child_c = nodes.mailbox["c"]
        h_tild = child_h.sum(axis=1)
        f = self.U_f(child_h).sigmoid()
        c = (f * child_c).sum(axis=1)
        return {"iou": self.U_iou(h_tild), "c": c}

87
88

class TreeLSTM(gluon.nn.Block):
    """Tree-LSTM classifier producing per-node logits.

    Parameters
    ----------
    num_vocabs : int
        Vocabulary size of the embedding table.
    x_size : int
        Embedding size.
    h_size : int
        Hidden size of the Tree-LSTM cell.
    num_classes : int
        Number of output classes of the final ``Dense`` layer.
    dropout : float
        Dropout rate, applied to the embeddings and to the final hidden
        states.
    cell_type : str
        ``"nary"`` selects ``TreeLSTMCell``; any other value selects
        ``ChildSumTreeLSTMCell``.
    pretrained_emb : Tensor or None
        If given, the embedding is initialized on ``ctx`` right away and its
        weight overwritten with this tensor; otherwise it is left for the
        caller to initialize.
    ctx : Context or None
        Device the graph is moved to in ``forward``.
    """

    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        cell_type="nary",
        pretrained_emb=None,
        ctx=None,
    ):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)
        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)
        if cell_type == "nary":
            cell_cls = TreeLSTMCell
        else:
            cell_cls = ChildSumTreeLSTMCell
        self.cell = cell_cls(x_size, h_size)
        self.ctx = ctx

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; ``graph``, ``wordid`` and ``mask`` are read.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        graph = batch.graph.to(self.ctx)
        mask = batch.mask
        # Word ids are multiplied by the mask before lookup, so masked nodes
        # look up id 0 (assuming a 0/1 mask -- confirm).
        embeds = self.embedding(batch.wordid * mask)
        wiou = self.cell.W_iou(self.dropout(embeds))
        # Zero the input contribution to the gates on masked nodes.
        node_mask = mask.expand_dims(-1).astype(wiou.dtype)
        graph.ndata["iou"] = wiou * node_mask
        graph.ndata["h"] = h
        graph.ndata["c"] = c
        # Propagate messages in topological order (children before parents).
        dgl.prop_nodes_topo(
            graph,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )
        hidden = self.dropout(graph.ndata.pop("h"))
        return self.linear(hidden)