# nx_SST.py
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader as CorpusReader
import networkx as nx
import torch as th

class nx_BCT_Reader:
    """Binary Constituency Tree constructor for networkx.

    Reads three bracketed-parse corpus files (train / dev / test) with
    NLTK's ``BracketParseCorpusReader``, builds a word -> index vocabulary
    from the training split only, and converts individual parse trees into
    ``networkx.DiGraph`` objects whose edges point from child to parent.

    Attributes set by ``__init__``:
        corpus: the NLTK corpus reader over all three files.
        train, dev, test: parsed sentences of the respective split.
        vocab: dict mapping each training word to a consecutive index
            starting at 0, in order of first appearance.
        default: index used for words missing from ``vocab``.
        LongTensor, FloatTensor: tensor constructors (CUDA or CPU).
    """

    # Width of the one-hot label vector attached to internal nodes.
    N_CLASSES = 5

    def __init__(self, cuda=False,
                 fnames=('trees/train.txt', 'trees/dev.txt', 'trees/test.txt')):
        """Load the corpus and build the vocabulary.

        Args:
            cuda: if true, tensors stored on graph nodes are created with
                the ``th.cuda`` tensor constructors, otherwise with the CPU
                ones.
            fnames: three file paths, relative to the current directory:
                the train, validation and test sets, respectively.
        """
        # Hand the reader its own list so it never holds a reference to the
        # (shared) default argument.
        fnames = list(fnames)
        self.corpus = CorpusReader('.', fnames)
        self.train = self.corpus.parsed_sents(fnames[0])
        self.dev = self.corpus.parsed_sents(fnames[1])
        self.test = self.corpus.parsed_sents(fnames[2])

        self.vocab = self._build_vocab(self.train)
        # NOTE(review): vocab indices run 0..len(vocab)-1, so index
        # len(vocab) is never produced; kept as-is because downstream
        # embedding sizes may depend on this value.
        self.default = len(self.vocab) + 1

        self.LongTensor = th.cuda.LongTensor if cuda else th.LongTensor
        self.FloatTensor = th.cuda.FloatTensor if cuda else th.FloatTensor

    @staticmethod
    def _leaf_token(child):
        """Return the word carried by ``child``, or None for an internal node.

        ``child`` is a leaf when it is a bare string, or when it is a
        non-empty sequence whose first element is a string (a pre-terminal
        such as ``(2 word)``).
        """
        if isinstance(child, str):
            # Bare token directly under its parent: use the whole word,
            # not its first character.
            return child
        if len(child) > 0 and isinstance(child[0], str):
            return child[0]
        return None

    @classmethod
    def _build_vocab(cls, trees):
        """Map every leaf word in ``trees`` to a consecutive index from 0."""
        vocab = {}

        def _visit(node):
            for child in node:
                token = cls._leaf_token(child)
                if token is None:
                    _visit(child)
                elif token not in vocab:
                    vocab[token] = len(vocab)
                # A known word is simply skipped; descending into its
                # pre-terminal would iterate the word string itself and
                # register its first character as a token.

        for tree in trees:
            _visit(tree)
        return vocab

    def create_BCT(self, root):
        """Convert one parse tree into a ``networkx.DiGraph``.

        Every node has attributes ``x`` and ``y``:
          * root: id ``'0_0'``, ``x=None``, ``y=None`` (the root's own label
            is not stored);
          * leaf: ``x`` is a 1-element LongTensor holding the vocabulary
            index of the word (``self.default`` if unknown), ``y=None``;
          * internal node: ``x=None``, ``y`` is a ``1 x N_CLASSES`` one-hot
            FloatTensor indexed by ``int(child.label())``.
        Other node ids have the form ``'<counter>_<depth>'``, where the
        counter increases in pre-order and the root's children have depth 1.
        Edges point from child to parent.

        The graph is also stored in ``self.G`` (and the counter in
        ``self.node_cnt``); both are overwritten on every call.

        Args:
            root: parse tree (an NLTK ``Tree``) for one sentence.

        Returns:
            The newly built ``networkx.DiGraph``.
        """
        self.node_cnt = 0
        self.G = nx.DiGraph()
        graph = self.G

        def _attach(subtree, parent_id, depth):
            for child in subtree:
                node_id = str(self.node_cnt) + '_' + str(depth)
                self.node_cnt += 1
                token = self._leaf_token(child)
                if token is not None:
                    word = self.LongTensor(
                        [self.vocab.get(token, self.default)])
                    graph.add_node(node_id, x=word, y=None)
                else:
                    label = self.FloatTensor([[0] * self.N_CLASSES])
                    label[0, int(child.label())] = 1
                    graph.add_node(node_id, x=None, y=label)
                    # Pass the depth down so ids reflect the real level.
                    _attach(child, node_id, depth + 1)
                graph.add_edge(node_id, parent_id)

        graph.add_node('0_0', x=None, y=None)  # add root into nx Graph
        _attach(root, '0_0', 1)

        return graph

    def generator(self, mode='train'):
        """Yield one graph per sentence of the chosen split.

        Because this is a generator function, ``mode`` is only validated
        when iteration starts.

        Args:
            mode: one of ``'train'``, ``'dev'`` or ``'test'``.

        Yields:
            The ``networkx.DiGraph`` returned by ``create_BCT`` for each
            sentence, in corpus order.

        Raises:
            AssertionError: if ``mode`` is not a known split. Raised
                explicitly (rather than via ``assert``) so the check also
                runs under ``python -O``; the exception type is unchanged.
        """
        if mode not in ('train', 'dev', 'test'):
            raise AssertionError(
                "mode must be 'train', 'dev' or 'test', got %r" % (mode,))
        for sent in getattr(self, mode):
            yield self.create_BCT(sent)