"tests/test_dictionary.py" did not exist on "7bcb487aad8504043d13c9b869d555aa565a46c7"
tree_lstm.py 3.95 KB
Newer Older
1
2
3
4
5
"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
6
7
8
import time

import dgl
9
import networkx as nx
10
import numpy as np
11
12
13
import torch as th
import torch.nn as nn
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
14

15

16
class TreeLSTMCell(nn.Module):
    """N-ary Tree-LSTM cell with the arity fixed to two children.

    The cell exposes the three callbacks used for message passing over a
    tree: ``message_func`` (child -> parent), ``reduce_func`` (combine the
    children's messages) and ``apply_node_func`` (gate arithmetic producing
    the node's new ``h`` and ``c``).

    Parameters
    ----------
    x_size : int
        Size of the input feature fed to ``W_iou``.
    h_size : int
        Size of the hidden state ``h`` and of the cell state ``c``.

    Notes
    -----
    ``W_iou`` is owned by the cell but is not applied by any method of this
    class; ``TreeLSTM.forward`` applies it to the embeddings and stores the
    result in the ``"iou"`` node feature before propagation.
    """

    def __init__(self, x_size, h_size):
        super().__init__()
        # Input projection for the stacked (i, o, u) pre-activations.  The
        # bias is kept separately in ``b_iou`` so it is added exactly once,
        # in ``apply_node_func``.
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        # Projection of the two concatenated child hidden states.
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        # One forget gate per child, computed from both children's states.
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        """Send the source node's ``h`` and ``c`` along each edge."""
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        """Combine the messages received by each node.

        The ``"h"`` mailbox is flattened per node, so its flattened width
        must equal ``2 * h_size`` (the input width of ``U_f`` and ``U_iou``).
        Expected mailbox layout is ``(num_nodes, 2, h_size)`` -- TODO confirm
        against the DGL version in use.

        Returns
        -------
        dict
            ``"iou"``: ``U_iou`` applied to the concatenated child states;
            ``"c"``: forget-gated sum of the children's cell states.
        """
        child_h = nodes.mailbox["h"]
        child_c = nodes.mailbox["c"]
        # Concatenate all child hidden states of a node into one row.
        h_cat = child_h.view(child_h.size(0), -1)
        # Forget gates, reshaped back so each child gets its own gate.
        f = th.sigmoid(self.U_f(h_cat)).view(*child_h.size())
        c = th.sum(f * child_c, 1)
        return {"iou": self.U_iou(h_cat), "c": c}

    def apply_node_func(self, nodes):
        """Turn the accumulated ``iou`` and ``c`` into the node's new state.

        Returns
        -------
        dict
            ``"h"``: new hidden state, ``"c"``: new cell state.
        """
        iou = nodes.data["iou"] + self.b_iou
        # ``iou`` stacks the three pre-activations along dim 1.
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # nodes.data["c"] already holds the forget-gated child contribution.
        c = i * u + nodes.data["c"]
        h = o * th.tanh(c)
        return {"h": h, "c": c}

41

42
43
44
class ChildSumTreeLSTMCell(nn.Module):
    """Child-Sum Tree-LSTM cell.

    Children's hidden states are summed before the ``iou`` projection, and
    the forget gate of each child is computed from that child's hidden state
    alone, so the layer sizes do not depend on the number of children.

    The cell exposes the three callbacks used for message passing over a
    tree: ``message_func``, ``reduce_func`` and ``apply_node_func``.

    Parameters
    ----------
    x_size : int
        Size of the input feature fed to ``W_iou``.
    h_size : int
        Size of the hidden state ``h`` and of the cell state ``c``.

    Notes
    -----
    ``W_iou`` is owned by the cell but is not applied by any method of this
    class; ``TreeLSTM.forward`` applies it to the embeddings and stores the
    result in the ``"iou"`` node feature before propagation.
    """

    def __init__(self, x_size, h_size):
        super().__init__()
        # Input projection for the stacked (i, o, u) pre-activations.  The
        # bias is kept separately in ``b_iou`` so it is added exactly once,
        # in ``apply_node_func``.
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        # Projection of the summed child hidden states.
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        # Forget-gate projection, applied to every child separately.
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        """Send the source node's ``h`` and ``c`` along each edge."""
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        """Combine the messages received by each node.

        Both mailbox tensors are reduced over dim 1; expected layout is
        ``(num_nodes, num_children, h_size)`` -- TODO confirm against the
        DGL version in use.

        Returns
        -------
        dict
            ``"iou"``: ``U_iou`` applied to the summed child hidden states;
            ``"c"``: forget-gated sum of the children's cell states.
        """
        child_h = nodes.mailbox["h"]
        child_c = nodes.mailbox["c"]
        h_tild = th.sum(child_h, 1)
        # nn.Linear acts on the last dim, giving one gate per child message.
        f = th.sigmoid(self.U_f(child_h))
        c = th.sum(f * child_c, 1)
        return {"iou": self.U_iou(h_tild), "c": c}

    def apply_node_func(self, nodes):
        """Turn the accumulated ``iou`` and ``c`` into the node's new state.

        Returns
        -------
        dict
            ``"h"``: new hidden state, ``"c"``: new cell state.
        """
        iou = nodes.data["iou"] + self.b_iou
        # ``iou`` stacks the three pre-activations along dim 1.
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # nodes.data["c"] already holds the forget-gated child contribution.
        c = i * u + nodes.data["c"]
        h = o * th.tanh(c)
        return {"h": h, "c": c}

67

68
class TreeLSTM(nn.Module):
    """Tree-LSTM model producing per-node class logits.

    Word ids are embedded, projected by the cell's ``W_iou``, propagated
    through the tree in topological order with the selected cell, and the
    resulting hidden states are classified by a linear layer.

    Parameters
    ----------
    num_vocabs : int
        Number of rows of the embedding table.
    x_size : int
        Embedding size (input size of the cell).
    h_size : int
        Hidden/cell state size of the cell.
    num_classes : int
        Output size of the final linear layer.
    dropout : float
        Value passed to ``nn.Dropout``; applied to the embeddings and to the
        hidden states before classification.
    cell_type : str, optional
        ``"nary"`` selects ``TreeLSTMCell``; any other value selects
        ``ChildSumTreeLSTMCell``.
    pretrained_emb : Tensor, optional
        If given, copied into the embedding weight; it must be
        copy-compatible with a ``(num_vocabs, x_size)`` tensor.
    """

    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        cell_type="nary",
        pretrained_emb=None,
    ):
        super().__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            self.embedding.weight.data.copy_(pretrained_emb)
            # Stated explicitly: the copied embeddings stay trainable.
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        cell = TreeLSTMCell if cell_type == "nary" else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)

    def forward(self, batch, g, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch. Its ``wordid`` and ``mask`` attributes are read.
        g : dgl.DGLGraph
            Tree for computation.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.

        Notes
        -----
        ``g`` is modified in place: ``"iou"``, ``"h"`` and ``"c"`` are
        written to ``g.ndata`` and ``"h"`` is popped again before returning,
        so ``"iou"`` and ``"c"`` remain on the graph afterwards.
        """
        # feed embedding
        # Multiplying by the mask maps masked entries to word id 0 for the
        # lookup; their projected input is zeroed again below.
        # NOTE(review): presumably mask is 1 for nodes that carry a word --
        # confirm against the batch definition.
        embeds = self.embedding(batch.wordid * batch.mask)
        node_mask = batch.mask.float().unsqueeze(-1)
        g.ndata["iou"] = self.cell.W_iou(self.dropout(embeds)) * node_mask
        g.ndata["h"] = h
        g.ndata["c"] = c
        # propagate
        dgl.prop_nodes_topo(
            g,
            self.cell.message_func,
            self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )
        # compute logits
        h = self.dropout(g.ndata.pop("h"))
        logits = self.linear(h)
        return logits