# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import pickle
from collections import Counter

import numpy as np
import torch
from torch.utils import data

logger = logging.getLogger("nni.textnas")


class PTBTree:
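    """A node in a Penn Treebank-style parse tree.

    Trees are read from the bracketed text format used by SST, e.g.
    ``(3 (2 word1) (4 word2))``: every node carries a label, and leaves
    additionally carry a word.
    """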
    WORD_TO_WORD_MAPPING = {
        "{": "-LCB-",
        "}": "-RCB-"
    }

    def __init__(self):
        self.subtrees = []
        self.word = None
        self.label = ""
        self.parent = None
        self.span = (-1, -1)
        self.word_vector = None  # placeholder for a d x 1 RNN word vector
        self.prediction = None  # placeholder for a K x 1 prediction vector

    def is_leaf(self):
        return len(self.subtrees) == 0

    def set_by_text(self, text, pos=0, left=0):
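        """Recursively parse the bracketed tree text starting at position ``pos``.

        ``left`` is the index of the leftmost word covered by this node;
        ``self.span`` is set to the ``(left, right)`` word interval.
        """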
        depth = 0
        right = left
        for i in range(pos + 1, len(text)):
            char = text[i]
            # update the depth
            if char == "(":
                depth += 1
                if depth == 1:
                    subtree = PTBTree()
                    subtree.parent = self
                    subtree.set_by_text(text, i, right)
                    right = subtree.span[1]
                    self.span = (left, right)
                    self.subtrees.append(subtree)
            elif char == ")":
                depth -= 1
                if len(self.subtrees) == 0:
                    pos = i
                    for j in range(i, 0, -1):
                        if text[j] == " ":
                            pos = j
                            break
                    self.word = text[pos + 1:i]
                    self.span = (left, left + 1)

            # we've reached the end of the category that is the root of this subtree
            if depth == 0 and char == " " and self.label == "":
                self.label = text[pos + 1:i]
            # we've reached the end of the scope for this bracket
            if depth < 0:
                break

        # Fix some issues with variation in output, and one error in the treebank
        # for a word with a punctuation POS
        self.standardise_node()

    def standardise_node(self):
        if self.word in self.WORD_TO_WORD_MAPPING:
            self.word = self.WORD_TO_WORD_MAPPING[self.word]

    def __repr__(self, single_line=True, depth=0):
        ans = ""
        if not single_line and depth > 0:
            ans = "\n" + depth * "\t"
        ans += "(" + self.label
        if self.word is not None:
            ans += " " + self.word
        for subtree in self.subtrees:
            if single_line:
                ans += " "
            ans += subtree.__repr__(single_line, depth + 1)
        ans += ")"
        return ans


def read_tree(source):
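    """Read lines from ``source`` until the brackets balance, then parse one tree.

    Returns ``None`` when the input is exhausted.
    """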
    cur_text = []
    depth = 0
    while True:
        line = source.readline()
        # Check if we are out of input
        if line == "":
            return None
        # strip whitespace and only use if this contains something
        line = line.strip()
        if line == "":
            continue
        cur_text.append(line)
        # Update depth
        for char in line:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
        # At depth 0 we have a complete tree
        if depth == 0:
            tree = PTBTree()
            tree.set_by_text(" ".join(cur_text))
            return tree


def read_trees(source, max_sents=-1):
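    """Read up to ``max_sents`` trees from ``source`` (all trees when ``max_sents`` <= 0)."""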
    with open(source) as fp:
        trees = []
        while True:
            tree = read_tree(fp)
            if tree is None:
                break
            trees.append(tree)
            if 0 < max_sents <= len(trees):
                break
        return trees


class SSTDataset(data.Dataset):
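    """``Dataset`` yielding ``((sentence_ids, mask), label)`` tuples."""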
    def __init__(self, sents, mask, labels):
        self.sents = sents
        self.labels = labels
        self.mask = mask

    def __getitem__(self, index):
        return (self.sents[index], self.mask[index]), self.labels[index]

    def __len__(self):
        return len(self.sents)


def sst_get_id_input(content, word_id_dict, max_input_length):
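    """Convert a sentence to a fixed-length list of word ids plus a 0/1 mask.

    Positions beyond the sentence (or beyond ``max_input_length``) keep the
    ``<pad>`` id and a mask of 0; words missing from ``word_id_dict`` map to
    ``<unknown>``.
    """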
    words = content.split(" ")
    sentence = [word_id_dict["<pad>"]] * max_input_length
    mask = [0] * max_input_length
    unknown = word_id_dict["<unknown>"]
    for i, word in enumerate(words[:max_input_length]):
        sentence[i] = word_id_dict.get(word, unknown)
        mask[i] = 1
    return sentence, mask


def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
    all_phrases = []
    for tree in trees:
        if only_sentence:
            sentence = get_sentence_by_tree(tree)
            label = int(tree.label)
            pair = (sentence, label)
            all_phrases.append(pair)
        else:
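            # The full-sentence pair is also included in get_phrases_by_tree's
            # output, so it appears twice in all_phrases.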
            phrases = get_phrases_by_tree(tree)
            sentence = get_sentence_by_tree(tree)
            pair = (sentence, int(tree.label))
            all_phrases.append(pair)
            all_phrases += phrases
    if sample_ratio < 1.:
        np.random.shuffle(all_phrases)
    result_phrases = []
    for pair in all_phrases:
        if is_binary:
            phrase, label = pair
            if label <= 1:
                pair = (phrase, 0)
            elif label >= 3:
                pair = (phrase, 1)
            else:
                continue
        if sample_ratio == 1.:
            result_phrases.append(pair)
        else:
            rand_portion = np.random.random()
            if rand_portion < sample_ratio:
                result_phrases.append(pair)
    return result_phrases


def get_phrases_by_tree(tree):
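    """Recursively collect a (phrase, label) pair for every node of a binarized tree."""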
    phrases = []
    if tree is None:
        return phrases
    if tree.is_leaf():
        pair = (tree.word, int(tree.label))
        phrases.append(pair)
        return phrases
    left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
    right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
    phrases.extend(left_child_phrases)
    phrases.extend(right_child_phrases)
    sentence = get_sentence_by_tree(tree)
    pair = (sentence, int(tree.label))
    phrases.append(pair)
    return phrases


def get_sentence_by_tree(tree):
    if tree is None:
        return ""
    if tree.is_leaf():
        return tree.word
    left_sentence = get_sentence_by_tree(tree.subtrees[0])
    right_sentence = get_sentence_by_tree(tree.subtrees[1])
    sentence = left_sentence + " " + right_sentence
    return sentence.strip()


def get_word_id_dict(word_num_dict, word_id_dict, min_count):
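    """Assign ids, in sorted word order, to words occurring at least ``min_count`` times,
    continuing from the ids already present in ``word_id_dict``."""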
    for word in sorted(word_num_dict.keys()):
        count = word_num_dict[word]
        if count >= min_count and word not in word_id_dict:
            word_id_dict[word] = len(word_id_dict)
    return word_id_dict


def load_word_num_dict(phrases, word_num_dict):
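    """Count word occurrences across all phrases into ``word_num_dict``."""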
    for sentence, _ in phrases:
        words = sentence.split(" ")
        for cur_word in words:
            word = cur_word.strip()
            word_num_dict[word] += 1
    return word_num_dict


def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
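    """Build the initial embedding matrix for the vocabulary.

    Row 0 (``<pad>``) is all zeros and row 1 (``<unknown>``) is random; every
    other word uses its GloVe vector when available, otherwise a random vector
    drawn uniformly from [-0.25, 0.25).
    """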
    word_embed_model = load_glove_model(embedding_path, embed_dim)
    assert word_embed_model["pool"].shape[1] == embed_dim
    embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
    embedding[0] = np.zeros(embed_dim)  # PAD
    embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2  # UNK
    for word in sorted(word_id_dict.keys()):
        idx = word_id_dict[word]
        if idx == 0 or idx == 1:
            continue
        if word in word_embed_model["mapping"]:
            embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
        else:
            embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
    return embedding


def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
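    """Vectorize (phrase, label) pairs into an ``SSTDataset`` of padded id sequences."""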
    texts, labels, mask = [], [], []

    for phrase, label in phrases:
        if not phrase.split():
            continue
        phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
        texts.append(phrase_split)
        labels.append(int(label))
        mask.append(mask_split)  # 1 for real tokens, 0 for padding
    labels = np.array(labels, dtype=np.int64)
    texts = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
    mask = np.reshape(mask, [-1, max_input_length]).astype(np.int32)

    return SSTDataset(texts, mask, labels)


def load_glove_model(filename, embed_dim):
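    """Load GloVe vectors from a whitespace-separated text file.

    Each line holds a word followed by ``embed_dim`` floats. The parsed result
    is pickled to ``filename + ".cache"`` so that later runs load faster.
    """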
    if os.path.exists(filename + ".cache"):
        logger.info("Found cache. Loading...")
        with open(filename + ".cache", "rb") as fp:
            return pickle.load(fp)
    embedding = {"mapping": dict(), "pool": []}
    with open(filename) as f:
        for i, line in enumerate(f):
            line = line.rstrip("\n")
            vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            assert len(vec) == embed_dim, "Unexpected line: '%s'" % line
            embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
            embedding["mapping"][vocab_word] = i
    embedding["pool"] = np.stack(embedding["pool"])
    with open(filename + ".cache", "wb") as fp:
        pickle.dump(embedding, fp)
    return embedding


def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
                  train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
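    """Load the SST datasets and a GloVe-initialized embedding matrix.

    Expects ``<data_path>/sst/trees/{train,dev,test}.txt`` and
    ``<data_path>/glove.840B.300d.txt``. Returns the train/valid/test
    ``SSTDataset`` objects and the embedding as a torch tensor.
    """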
    word_id_dict = dict()
    word_num_dict = Counter()

    sst_path = os.path.join(data_path, "sst")
    logger.info("Reading SST data...")
    train_file_name = os.path.join(sst_path, "trees", "train.txt")
    valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
    test_file_name = os.path.join(sst_path, "trees", "test.txt")
    train_trees = read_trees(train_file_name)
    train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
    logger.info("Finish load train phrases.")
    valid_trees = read_trees(valid_file_name)
    valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
    logger.info("Finish load valid phrases.")
    if train_with_valid:
        train_phrases += valid_phrases
    test_trees = read_trees(test_file_name)
    test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
    logger.info("Finish load test phrases.")

    # get word_id_dict
    word_id_dict["<pad>"] = 0
    word_id_dict["<unknown>"] = 1
    load_word_num_dict(train_phrases, word_num_dict)
    logger.info("Finish load train words: %d.", len(word_num_dict))
    load_word_num_dict(valid_phrases, word_num_dict)
    load_word_num_dict(test_phrases, word_num_dict)
    logger.info("Finish load valid+test words: %d.", len(word_num_dict))
    word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
    logger.info("After trim vocab length: %d.", len(word_id_dict))

    logger.info("Loading embedding...")
    embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
    logger.info("Finish initialize word embedding.")

    dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d training samples.", len(dataset_train))
    dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d validation samples.", len(dataset_valid))
    dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d test samples.", len(dataset_test))

    return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
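

# A minimal usage sketch, not part of the original module: how the returned
# datasets might be wrapped in DataLoaders. The "./data" location and the
# batch size of 128 are assumptions for illustration.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    train_set, valid_set, test_set, embedding = read_data_sst("./data", is_binary=True)
    train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True)
    for (sents, mask), labels in train_loader:
        # sents and mask have shape [batch, max_input_length]; labels has shape [batch]
        print(sents.shape, mask.shape, labels.shape)
        break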