"""Tree-structured data.
Including:
    - Stanford Sentiment Treebank
"""
from __future__ import absolute_import

import os

from collections import OrderedDict

import networkx as nx
import numpy as np

from .. import backend as F
from ..convert import from_networkx

from .dgl_dataset import DGLBuiltinDataset
from .utils import (
    _get_dgl_url,
    load_graphs,
    load_info,
    save_graphs,
    save_info,
)

__all__ = ["SST", "SSTDataset"]


class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. The word is an int value stored in the ``x`` feature field.
    The non-leaf node has a special value ``PAD_WORD`` in the ``x`` field.
    Each node also has a sentiment annotation: 5 classes (very negative,
    negative, neutral, positive and very positive). The sentiment label is an
    int value stored in the ``y`` feature field.
    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_

    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5

    Parameters
    ----------
    mode : str, optional
        Should be one of ['train', 'dev', 'test', 'tiny']
        Default: train
    glove_embed_file : str, optional
        The path to the pretrained GloVe embedding file. Only effective
        when ``mode`` is 'train'.
        Default: None
    vocab_file : str, optional
        Optional vocabulary file. If not given, the default vocabulary file is used.
        Default: None
    raw_dir : str
        Raw file directory to download and store the input data.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    vocab : OrderedDict
        Vocabulary of the dataset
    num_classes : int
        Number of classes for each node
    pretrained_emb : Tensor
        Pretrained GloVe embedding with respect to the vocabulary.
    vocab_size : int
        The size of the vocabulary

    Notes
    -----
    All the samples will be loaded and preprocessed in memory first.

    Examples
    --------
    >>> # get dataset
    >>> train_data = SSTDataset()
    >>> dev_data = SSTDataset(mode='dev')
    >>> test_data = SSTDataset(mode='test')
    >>> tiny_data = SSTDataset(mode='tiny')
    >>>
    >>> len(train_data)
    8544
    >>> train_data.num_classes
    5
    >>> glove_embed = train_data.pretrained_emb
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
          ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
    ...     # your code here
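    >>>
    >>> # a minimal batching sketch (assumes the PyTorch backend and
    >>> # torch.utils.data.DataLoader); dgl.batch merges a list of
    >>> # trees into a single batched graph
    >>> import dgl
    >>> from torch.utils.data import DataLoader
    >>> loader = DataLoader(train_data, batch_size=32, collate_fn=dgl.batch)
    >>> batched_trees = next(iter(loader))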
    """

    PAD_WORD = -1  # special pad word id
    UNK_WORD = -1  # out-of-vocabulary word id

    def __init__(
        self,
        mode="train",
        glove_embed_file=None,
        vocab_file=None,
        raw_dir=None,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        assert mode in ["train", "dev", "test", "tiny"]
        _url = _get_dgl_url("dataset/sst.zip")
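        # keep the GloVe file only for the training split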
        self._glove_embed_file = glove_embed_file if mode == "train" else None
        self.mode = mode
        self._vocab_file = vocab_file
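        # the base class downloads the data, then calls process()/save()
        # on a fresh run or load() when a cache exists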
        super(SSTDataset, self).__init__(
            name="sst",
            url=_url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        from nltk.corpus.reader import BracketParseCorpusReader

        # load vocab file
        self._vocab = OrderedDict()
        vocab_file = (
            self._vocab_file
            if self._vocab_file is not None
            else os.path.join(self.raw_path, "vocab.txt")
        )
        with open(vocab_file, encoding="utf-8") as vf:
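            # one word per line; ids are assigned in file order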
            for line in vf:
                line = line.strip()
                self._vocab[line] = len(self._vocab)

        # filter glove
        use_glove = self._glove_embed_file is not None and os.path.exists(
            self._glove_embed_file
        )
        if use_glove:
            glove_emb = {}
            with open(self._glove_embed_file, "r", encoding="utf-8") as pf:
                for line in pf:
                    sp = line.split(" ")
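                    # sp[0] is the word; the rest of the line is its vector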
                    if sp[0].lower() in self._vocab:
                        glove_emb[sp[0].lower()] = np.asarray(
                            [float(x) for x in sp[1:]]
                        )
        files = ["{}.txt".format(self.mode)]
        corpus = BracketParseCorpusReader(self.raw_path, files)
        sents = corpus.parsed_sents(files[0])

        # initialize with glove
        pretrained_emb = []
        fail_cnt = 0
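        # words missing from GloVe fall back to a random 300-d vector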
        for line in self._vocab.keys():
            if use_glove:
                if line.lower() not in glove_emb:
                    fail_cnt += 1
                pretrained_emb.append(
                    glove_emb.get(
                        line.lower(), np.random.uniform(-0.05, 0.05, 300)
                    )
                )

        self._pretrained_emb = None
        if use_glove:
            self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
            print(
                "GloVe miss rate: {:.4f}".format(
                    1.0 * fail_cnt / len(self._pretrained_emb)
                )
            )
        # build trees
        self._trees = []
        for sent in sents:
            self._trees.append(self._build_tree(sent))

    def _build_tree(self, root):
        g = nx.DiGraph()

        def _rec_build(nid, node):
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], (str, bytes)):
                    # leaf node
                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                    g.add_node(cid, x=word, y=int(child.label()), mask=1)
                else:
                    g.add_node(
                        cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0
                    )
                    _rec_build(cid, child)
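                # edges point from child to parent so messages flow bottom-up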
                g.add_edge(cid, nid)

        # add root
        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
        _rec_build(0, root)
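        # convert to a DGLGraph, keeping the per-node attributes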
        ret = from_networkx(g, node_attrs=["x", "y", "mask"])
        return ret

    @property
    def graph_path(self):
        return os.path.join(self.save_path, self.mode + "_dgl_graph.bin")

    @property
    def vocab_path(self):
        return os.path.join(self.save_path, "vocab.pkl")

    def has_cache(self):
        return os.path.exists(self.graph_path) and os.path.exists(
            self.vocab_path
        )

    def save(self):
        save_graphs(self.graph_path, self._trees)
        save_info(self.vocab_path, {"vocab": self.vocab})
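        # cache the pretrained embedding only when one was built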
        if self.pretrained_emb is not None:
            emb_path = os.path.join(self.save_path, "emb.pkl")
            save_info(emb_path, {"embed": self.pretrained_emb})

    def load(self):
        emb_path = os.path.join(self.save_path, "emb.pkl")

        self._trees = load_graphs(self.graph_path)[0]
        self._vocab = load_info(self.vocab_path)["vocab"]
        self._pretrained_emb = None
        if os.path.exists(emb_path):
            self._pretrained_emb = load_info(emb_path)["embed"]

    @property
    def vocab(self):
        r"""Vocabulary

        Returns
        -------
        OrderedDict
        """
        return self._vocab

    @property
    def pretrained_emb(self):
        r"""Pre-trained word embedding, if given."""
        return self._pretrained_emb

    def __getitem__(self, idx):
        r"""Get graph by index

        Parameters
        ----------
        idx : int

        Returns
        -------
        :class:`dgl.DGLGraph`

            Graph structure, word id for each node, node labels and masks.

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
        """
        if self._transform is None:
            return self._trees[idx]
        else:
            return self._transform(self._trees[idx])

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._trees)

    @property
    def vocab_size(self):
        r"""Vocabulary size."""
        return len(self._vocab)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 5


SST = SSTDataset