"""Tree-structured data.
Including:
    - Stanford Sentiment Treebank
"""
from __future__ import absolute_import

from collections import OrderedDict
import networkx as nx
import numpy as np
import os

from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
    load_info, deprecate_property
from ..convert import from_networkx

__all__ = ['SST', 'SSTDataset']


class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.
    .. deprecated:: 0.5.0
        
        - ``trees`` is deprecated; it is replaced by:

            >>> dataset = SSTDataset()
            >>> for tree in dataset:
            ...     # your code here

        - ``num_vocabs`` is deprecated; it is replaced by ``vocab_size``.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. The word is an integer stored in the ``x`` feature field.
    Each non-leaf node has the special value ``PAD_WORD`` in the ``x`` field.
    Each node also has a sentiment annotation: 5 classes (very negative,
    negative, neutral, positive and very positive). The sentiment label is an
    int value stored in the ``y`` feature field.
    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_
    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5
    Parameters
    ----------
    mode : str, optional
        Should be one of ['train', 'dev', 'test', 'tiny'].
        Default: 'train'
    glove_embed_file : str, optional
        The path to the pretrained GloVe embedding file.
        Default: None
    vocab_file : str, optional
        Optional vocabulary file. If not given, the default vocabulary file is used.
        Default: None
    raw_dir : str
        Raw file directory to download/store the input data.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False

    Attributes
    ----------
    vocab : OrderedDict
        Vocabulary of the dataset
    trees : list
        A list of DGLGraph objects
    num_classes : int
        Number of classes for each node
    pretrained_emb : Tensor
        Pretrained GloVe embedding with respect to the vocabulary.
    vocab_size : int
        The size of the vocabulary
    num_vocabs : int
        The size of the vocabulary (deprecated, use ``vocab_size`` instead)

    Notes
    -----
    All the samples are loaded and preprocessed in memory first.

    Examples
    --------
    >>> # get dataset
    >>> train_data = SSTDataset()
    >>> dev_data = SSTDataset(mode='dev')
    >>> test_data = SSTDataset(mode='test')
    >>> tiny_data = SSTDataset(mode='tiny')
    >>>
    >>> len(train_data)
    8544
    >>> train_data.num_classes
    5
    >>> glove_embed = train_data.pretrained_emb
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
          ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
    ...     # your code here
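
    The trees can also be combined into a single graph for batched
    computation, e.g. with :func:`dgl.batch` (illustrative usage):

    >>> import dgl
    >>> batch = dgl.batch([train_data[i] for i in range(16)])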
    """

    PAD_WORD = -1  # special pad word id
    UNK_WORD = -1  # out-of-vocabulary word id
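    # Note: both constants share the value -1. Out-of-vocabulary leaf words
    # and padded internal nodes are told apart by the ``mask`` node feature,
    # not by ``x``.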

    def __init__(self,
                 mode='train',
                 glove_embed_file=None,
                 vocab_file=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=False):
        assert mode in ['train', 'dev', 'test', 'tiny']
        _url = _get_dgl_url('dataset/sst.zip')
        self._glove_embed_file = glove_embed_file if mode == 'train' else None
        self.mode = mode
        self._vocab_file = vocab_file
        super(SSTDataset, self).__init__(name='sst',
                                         url=_url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose)

    def process(self):
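        # NLTK is imported lazily: it is only needed when the raw data has
        # to be processed, i.e. when no cached graphs are available.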
        from nltk.corpus.reader import BracketParseCorpusReader
        # load vocab file
        self._vocab = OrderedDict()
        vocab_file = self._vocab_file
        if vocab_file is None:
            vocab_file = os.path.join(self.raw_path, 'vocab.txt')
        with open(vocab_file, encoding='utf-8') as vf:
            for line in vf:
                line = line.strip()
                self._vocab[line] = len(self._vocab)

        # filter glove
        use_glove = self._glove_embed_file is not None and \
            os.path.exists(self._glove_embed_file)
        if use_glove:
            glove_emb = {}
            with open(self._glove_embed_file, 'r', encoding='utf-8') as pf:
                for line in pf:
                    # GloVe text format: one token followed by its embedding
                    # values, separated by single spaces.
                    sp = line.split(' ')
                    if sp[0].lower() in self._vocab:
                        glove_emb[sp[0].lower()] = np.asarray(
                            [float(x) for x in sp[1:]])
        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader(self.raw_path, files)
        sents = corpus.parsed_sents(files[0])
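        # Each element of ``sents`` is an NLTK ``Tree`` parsed from the
        # PTB-style bracketed annotations in ``<mode>.txt``; the trees are
        # converted to DGLGraphs below.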

        # initialize with glove
        pretrained_emb = []
        fail_cnt = 0
        for word in self._vocab.keys():
            if use_glove:
                if word.lower() not in glove_emb:
                    fail_cnt += 1
                pretrained_emb.append(glove_emb.get(
                    word.lower(), np.random.uniform(-0.05, 0.05, 300)))

        self._pretrained_emb = None
        if use_glove:
            self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
            print('GloVe miss rate: {:.4f}'.format(
                1.0 * fail_cnt / len(self._pretrained_emb)))
        # build trees
        self._trees = []
        for sent in sents:
            self._trees.append(self._build_tree(sent))

    def _build_tree(self, root):
        g = nx.DiGraph()

        def _rec_build(nid, node):
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], (str, bytes)):
                    # leaf node
                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                    g.add_node(cid, x=word, y=int(child.label()), mask=1)
                else:
                    g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0)
                    _rec_build(cid, child)
                g.add_edge(cid, nid)  # edges point from child to parent (bottom-up)

        # add root
        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
        _rec_build(0, root)
        ret = from_networkx(g, node_attrs=['x', 'y', 'mask'])
        return ret

    def has_cache(self):
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        return os.path.exists(graph_path) and os.path.exists(vocab_path)

    def save(self):
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        save_graphs(graph_path, self._trees)
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        save_info(vocab_path, {'vocab': self.vocab})
        if self.pretrained_emb is not None:
            emb_path = os.path.join(self.save_path, 'emb.pkl')
            save_info(emb_path, {'embed': self.pretrained_emb})

    def load(self):
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        emb_path = os.path.join(self.save_path, 'emb.pkl')

        self._trees = load_graphs(graph_path)[0]
        self._vocab = load_info(vocab_path)['vocab']
        self._pretrained_emb = None
        if os.path.exists(emb_path):
            self._pretrained_emb = load_info(emb_path)['embed']

    @property
    def trees(self):
        deprecate_property('dataset.trees',
                           '[dataset[i] for i in range(len(dataset))]')
        return self._trees

    @property
    def vocab(self):
        r""" Vocabulary

        Returns
        -------
        OrderedDict
        """
        return self._vocab

    @property
    def pretrained_emb(self):
        r"""Pre-trained word embedding, if given."""
        return self._pretrained_emb

    def __getitem__(self, idx):
        r"""Get graph by index.

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        :class:`dgl.DGLGraph`

            graph structure, word id for each node, node labels and masks.

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
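
        Examples
        --------
        An illustrative way to get the word ids of the leaf nodes of one
        tree (``mask == 1`` marks leaf nodes):

        >>> dataset = SSTDataset(mode='tiny')
        >>> g = dataset[0]
        >>> leaf_word_ids = g.ndata['x'][g.ndata['mask'] == 1]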
        """
        return self._trees[idx]

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._trees)

    @property
    def num_vocabs(self):
        r"""Deprecated, use ``vocab_size`` instead."""
        deprecate_property('dataset.num_vocabs', 'dataset.vocab_size')
        return self.vocab_size

    @property
    def vocab_size(self):
        r"""Vocabulary size."""
        return len(self._vocab)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 5


SST = SSTDataset