"""Tree-structured data.
Including:
    - Stanford Sentiment Treebank
"""
from __future__ import absolute_import

from collections import OrderedDict
import os

import networkx as nx
import numpy as np

from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
    load_info, deprecate_property
from ..convert import from_networkx
__all__ = ['SST', 'SSTDataset']


class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. Each word is an int value stored in the ``x`` feature
    field. Non-leaf nodes have the special value ``PAD_WORD`` in the ``x``
    field. Each node also has a sentiment annotation: 5 classes (very
    negative, negative, neutral, positive and very positive). The sentiment
    label is an int value stored in the ``y`` feature field.

    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_

    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5

    Parameters
    ----------
    mode : str, optional
        Should be one of ['train', 'dev', 'test', 'tiny'].
        Default: 'train'
    glove_embed_file : str, optional
        The path to the pretrained GloVe embedding file.
        Default: None
    vocab_file : str, optional
        Optional vocabulary file. If not given, the default vocabulary file is used.
        Default: None
    raw_dir : str
        Directory to which the raw data is downloaded, or which already contains it.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    Attributes
    ----------
    vocab : OrderedDict
        Vocabulary of the dataset
    num_classes : int
        Number of classes for each node
    pretrained_emb : Tensor
        Pretrained GloVe embedding with respect to the vocabulary.
    vocab_size : int
        The size of the vocabulary

    Notes
    -----
    All the samples are loaded and preprocessed in memory first.

    Examples
    --------
    >>> # get dataset
    >>> train_data = SSTDataset()
    >>> dev_data = SSTDataset(mode='dev')
    >>> test_data = SSTDataset(mode='test')
    >>> tiny_data = SSTDataset(mode='tiny')
    >>>
    >>> len(train_data)
    8544
    >>> train_data.num_classes
    5
    >>> glove_embed = train_data.pretrained_emb
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
      ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
    ...     # your code here
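
    The trees can also be batched into one graph for efficient training,
    e.g. with :func:`dgl.batch` (a minimal sketch):

    >>> import dgl
    >>> batch = dgl.batch([train_data[i] for i in range(16)])
    >>> batch.num_nodes() == sum(train_data[i].num_nodes() for i in range(16))
    True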
    """
    PAD_WORD = -1  # special pad word id, used for internal (non-leaf) nodes
    UNK_WORD = -1  # out-of-vocabulary word id (note: same value as PAD_WORD)

    def __init__(self,
                 mode='train',
                 glove_embed_file=None,
                 vocab_file=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=False,
                 transform=None):
        assert mode in ['train', 'dev', 'test', 'tiny']
        _url = _get_dgl_url('dataset/sst.zip')
        self._glove_embed_file = glove_embed_file if mode == 'train' else None
        self.mode = mode
        self._vocab_file = vocab_file
        super(SSTDataset, self).__init__(name='sst',
                                         url=_url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose,
                                         transform=transform)

    def process(self):
        from nltk.corpus.reader import BracketParseCorpusReader
        # load the vocabulary file: one token per line, id = line index
        self._vocab = OrderedDict()
        vocab_file = self._vocab_file if self._vocab_file is not None else os.path.join(self.raw_path, 'vocab.txt')
        with open(vocab_file, encoding='utf-8') as vf:
            for line in vf.readlines():
                line = line.strip()
                self._vocab[line] = len(self._vocab)
        # keep only the GloVe vectors for words in the vocabulary
        if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
            glove_emb = {}
            with open(self._glove_embed_file, 'r', encoding='utf-8') as pf:
                for line in pf.readlines():
                    sp = line.split(' ')
                    if sp[0].lower() in self._vocab:
                        glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]])
        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader(self.raw_path, files)
        sents = corpus.parsed_sents(files[0])

        # initialize the embedding matrix with GloVe vectors; words missing
        # from GloVe are initialized uniformly in [-0.05, 0.05]
        pretrained_emb = []
        fail_cnt = 0
        for line in self._vocab.keys():
            if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
                if line.lower() not in glove_emb:
                    fail_cnt += 1
                pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300)))

        self._pretrained_emb = None
        if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
            self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
            print('Miss rate of GloVe embeddings: {0:.4f}'.format(1.0 * fail_cnt / len(self._pretrained_emb)))
        # build trees
        self._trees = []
        for sent in sents:
            self._trees.append(self._build_tree(sent))

    def _build_tree(self, root):
        g = nx.DiGraph()
        def _rec_build(nid, node):
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], (str, bytes)):
                    # leaf node
                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                    g.add_node(cid, x=word, y=int(child.label()), mask=1)
                else:
                    # internal node
                    g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0)
                    _rec_build(cid, child)
                g.add_edge(cid, nid)  # edges point from child to parent (bottom-up)

        # add root
        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
        _rec_build(0, root)
        ret = from_networkx(g, node_attrs=['x', 'y', 'mask'])
        return ret
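
    # A minimal sketch of the input ``_build_tree`` consumes, assuming nltk
    # is installed (the sentence below is made up for illustration):
    #
    #   from nltk import Tree
    #   sent = Tree.fromstring('(3 (2 good) (3 movie))')
    #   # sent.label() == '3'; iterating over sent yields the two subtrees,
    #   # whose first elements are the strings 'good' and 'movie'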
    def has_cache(self):
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        return os.path.exists(graph_path) and os.path.exists(vocab_path)
    def save(self):
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        save_graphs(graph_path, self._trees)
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        save_info(vocab_path, {'vocab': self.vocab})
        if self.pretrained_emb is not None:
            emb_path = os.path.join(self.save_path, 'emb.pkl')
            save_info(emb_path, {'embed': self.pretrained_emb})
    def load(self):
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        emb_path = os.path.join(self.save_path, 'emb.pkl')

        self._trees = load_graphs(graph_path)[0]
        self._vocab = load_info(vocab_path)['vocab']
        self._pretrained_emb = None
        if os.path.exists(emb_path):
            self._pretrained_emb = load_info(emb_path)['embed']
    @property
    def vocab(self):
        r""" Vocabulary

        Returns
        -------
        OrderedDict
        """
        return self._vocab

    @property
    def pretrained_emb(self):
        r"""Pre-trained word embedding, if given."""
        return self._pretrained_emb

    def __getitem__(self, idx):
        r"""Get graph by index.

        Parameters
        ----------
        idx : int
            Item index.

        Returns
        -------
        :class:`dgl.DGLGraph`

            Graph structure, word id for each node, node labels and masks:

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
        """
        if self._transform is None:
            return self._trees[idx]
        else:
            return self._transform(self._trees[idx])

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._trees)

    @property
    def vocab_size(self):
        r"""Vocabulary size."""
        return len(self._vocab)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 5


SST = SSTDataset  # backward-compatible alias for the old class name
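

# A minimal usage sketch, assuming the dataset archive can be downloaded
# (the 'tiny' split keeps the run short):
if __name__ == '__main__':
    _data = SSTDataset(mode='tiny')
    print(len(_data), _data.vocab_size, _data.num_classes)
    _tree = _data[0]
    print(_tree.ndata['x'], _tree.ndata['y'], _tree.ndata['mask'])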