"""Tree-LSTM cells and model built on MXNet Gluon and DGL.

Reference: "Improved Semantic Representations From Tree-Structured Long
Short-Term Memory Networks", https://arxiv.org/abs/1503.00075
"""
import itertools
6
7
8
import time

import mxnet as mx
9
10
11
import networkx as nx
import numpy as np
from mxnet import gluon
12

13
14
import dgl

15

16
17
18
19
20
21
22
23
24
25
class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
    """Per-node update that turns the ``iou`` pre-activations into ``(h, c)``."""

    def hybrid_forward(self, F, iou, b_iou, c):
        """Apply the bias and gate non-linearities, then update the states.

        Parameters
        ----------
        F : module
            ``mxnet.ndarray`` or ``mxnet.symbol``, supplied by Gluon.
        iou : Tensor
            Concatenated gate pre-activations; split into three equal parts
            along axis 1.
        b_iou : Tensor
            Bias broadcast-added to ``iou``.
        c : Tensor
            Cell state carried into this node (added to ``i * u``).

        Returns
        -------
        h, c : Tensor
            The new hidden state and the new cell state.
        """
        biased = F.broadcast_add(iou, b_iou)
        in_gate, out_gate, update = biased.split(num_outputs=3, axis=1)

        in_gate = in_gate.sigmoid()
        out_gate = out_gate.sigmoid()
        update = update.tanh()

        # New cell state: gated candidate plus the incoming cell state.
        new_c = in_gate * update + c
        new_h = out_gate * new_c.tanh()
        return new_h, new_c

26

27
28
29
30
31
32
33
34
35
36
37
38
39
class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
    """Combine the children's ``h``/``c`` messages into ``iou`` and ``c``."""

    def __init__(self, U_iou, U_f):
        """Store the two layers applied to the flattened child hidden states.

        Parameters
        ----------
        U_iou : gluon.Block
            Layer producing the ``iou`` pre-activations.
        U_f : gluon.Block
            Layer producing the forget-gate pre-activations; its output is
            reshaped back to the shape of ``h``.
        """
        super(_TreeLSTMCellReduceFunc, self).__init__()
        self.U_iou = U_iou
        self.U_f = U_f

    def hybrid_forward(self, F, h, c):
        """Reduce child states.

        Parameters
        ----------
        F : module
            ``mxnet.ndarray`` or ``mxnet.symbol``, supplied by Gluon.
        h : Tensor
            Child hidden states (assumed stacked along axis 1 - TODO confirm
            against the caller).
        c : Tensor
            Child cell states, multiplied element-wise with the forget gates.

        Returns
        -------
        iou, c : Tensor
            ``U_iou`` applied to the flattened ``h``, and the forget-gated
            cell states summed over axis 1.
        """
        # Keep axis 0, merge every remaining axis into one.
        flat_h = h.reshape((0, -1))

        forget = self.U_f(flat_h).sigmoid()
        forget = forget.reshape_like(h)
        reduced_c = (forget * c).sum(axis=1)

        return self.U_iou(flat_h), reduced_c

40

41
42
43
44
class _TreeLSTMCell(gluon.HybridBlock):
    """Base cell holding the ``iou`` bias plus the message and apply steps.

    This class defines no ``reduce_func``; the subclasses in this file
    provide one.
    """

    def __init__(self, h_size):
        """Create the node-update block and a zero-initialised bias.

        Parameters
        ----------
        h_size : int
            Hidden size; the bias has shape ``(1, 3 * h_size)``.
        """
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        bias_shape = (1, 3 * h_size)
        self.b_iou = self.params.get("bias", shape=bias_shape, init="zeros")

    def message_func(self, edges):
        """Send the source node's ``h`` and ``c`` along each edge."""
        src = edges.src
        return {"h": src["h"], "c": src["c"]}

    def apply_node_func(self, nodes):
        """Compute the new ``h`` and ``c`` from each node's ``iou`` and ``c``."""
        data = nodes.data
        iou = data["iou"]
        # Fetch the bias copy that lives on the same context as ``iou``.
        bias = self.b_iou.data(iou.context)
        new_h, new_c = self._apply_node_func(iou, bias, data["c"])
        return {"h": new_h, "c": new_c}

58
59
60
61
62

class TreeLSTMCell(_TreeLSTMCell):
    """Tree-LSTM cell whose reduce step runs on the concatenated child states."""

    def __init__(self, x_size, h_size):
        """Build the reduce block and the input projection.

        Parameters
        ----------
        x_size : int
            Input feature size. Not used by this constructor.
        h_size : int
            Hidden size. The forget layer emits ``2 * h_size`` units
            (presumably one ``h_size`` slice per child of a binary tree -
            TODO confirm).
        """
        super(TreeLSTMCell, self).__init__(h_size)
        # Built in the same order as before: U_iou, U_f, then W_iou.
        u_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        u_f = gluon.nn.Dense(2 * h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(u_iou, u_f)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)

    def reduce_func(self, nodes):
        """Reduce the mailbox ``h``/``c`` messages into ``iou`` and ``c``."""
        mailbox = nodes.mailbox
        iou, c = self._reduce_func(mailbox["h"], mailbox["c"])
        return {"iou": iou, "c": c}

73
74
75
76
77
78
79
80
81

class ChildSumTreeLSTMCell(_TreeLSTMCell):
    """Tree-LSTM cell that sums child hidden states before the ``iou`` layer.

    Each child gets its own forget gate, computed from that child's hidden
    state alone, so the number of children per node may vary.
    """

    def __init__(self, x_size, h_size):
        """Build the input, ``iou`` and forget-gate projections.

        Parameters
        ----------
        x_size : int
            Input feature size. Not used by this constructor; kept so the
            signature matches ``TreeLSTMCell``.
        h_size : int
            Hidden size.
        """
        # The base class needs h_size to allocate the iou bias; calling it
        # without arguments raised TypeError.
        super(ChildSumTreeLSTMCell, self).__init__(h_size)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        # flatten=False applies the layer to the last axis only, so the child
        # axis of the mailbox survives and the gates line up with the child
        # cell states. The default flatten=True would merge the child axis
        # into the features, tying the weight shape to the child count.
        self.U_f = gluon.nn.Dense(h_size, flatten=False)

    def reduce_func(self, nodes):
        """Reduce the mailbox ``h``/``c`` messages into ``iou`` and ``c``.

        Returns a dict with ``"iou"`` (``U_iou`` applied to the child hidden
        states summed over axis 1) and ``"c"`` (forget-gated child cell
        states summed over axis 1).
        """
        child_h = nodes.mailbox["h"]
        h_tild = child_h.sum(axis=1)
        f = self.U_f(child_h).sigmoid()
        c = (f * nodes.mailbox["c"]).sum(axis=1)
        return {"iou": self.U_iou(h_tild), "c": c}

87
88

class TreeLSTM(gluon.nn.Block):
    """Embedding, Tree-LSTM cell propagated over a graph, then a linear head."""

    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        cell_type="nary",
        pretrained_emb=None,
        ctx=None,
    ):
        """Build the sub-blocks.

        Parameters
        ----------
        num_vocabs : int
            Number of rows of the embedding table.
        x_size : int
            Embedding size.
        h_size : int
            Hidden size handed to the cell.
        num_classes : int
            Output units of the final dense layer.
        dropout : float
            Rate handed to ``gluon.nn.Dropout``.
        cell_type : str
            ``"nary"`` selects ``TreeLSTMCell``; any other value selects
            ``ChildSumTreeLSTMCell``.
        pretrained_emb : Tensor, optional
            When given, the embedding is initialised on ``ctx`` and its
            weight is set to this value.
        ctx : Context, optional
            Used for the embedding initialisation above and as the target
            of ``graph.to`` in ``forward``.
        """
        super(TreeLSTM, self).__init__()
        self.x_size = x_size

        # Sub-blocks are created in a fixed order: embedding, dropout,
        # linear, cell.
        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        use_pretrained = pretrained_emb is not None
        if use_pretrained:
            print("Using glove")
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)

        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)

        if cell_type == "nary":
            cell_cls = TreeLSTMCell
        else:
            cell_cls = ChildSumTreeLSTMCell
        self.cell = cell_cls(x_size, h_size)

        self.ctx = ctx

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; its ``graph``, ``wordid`` and ``mask`` fields
            are read.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        graph = batch.graph.to(self.ctx)

        # Feed embeddings; the mask multiplies both the word ids and the
        # projected inputs.
        mask = batch.mask
        word_vecs = self.embedding(batch.wordid * mask)
        w_iou = self.cell.W_iou(self.dropout(word_vecs))
        graph.ndata["iou"] = w_iou * mask.expand_dims(-1).astype(w_iou.dtype)
        graph.ndata["h"] = h
        graph.ndata["c"] = c

        # Propagate through the graph with the cell's three callbacks.
        cell = self.cell
        dgl.prop_nodes_topo(
            graph,
            message_func=cell.message_func,
            reduce_func=cell.reduce_func,
            apply_node_func=cell.apply_node_func,
        )

        # Compute logits from the final hidden states.
        hidden = graph.ndata.pop("h")
        return self.linear(self.dropout(hidden))