import os
import dgl
import torch as th
import numpy as np
import scipy.io as sio
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import save_graphs, load_graphs, _get_dgl_url


class GASDataset(DGLBuiltinDataset):
    r"""Retweet-graph dataset ('pol' or 'gos') for the GAS spam/fake-news model.

    Builds a single bidirected news/user retweet graph with node and edge
    features, per-edge binary labels (an edge is positive when it touches a
    fake-news node), and random train/val/test edge masks.

    Parameters
    ----------
    name : str
        Dataset name, either 'pol' or 'gos'.
    raw_dir : str, optional
        Directory that stores (or will receive) the downloaded raw data.
    random_seed : int, optional
        Seed for the train/val/test edge split. Default: 717.
    train_size : float, optional
        Fraction of edges used for training. Default: 0.7.
    val_size : float, optional
        Fraction of edges used for validation. Default: 0.1.
    """

    file_urls = {
        'pol': 'dataset/GASPOL.zip',
        'gos': 'dataset/GASGOS.zip'
    }

    def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1):
        assert name in ['gos', 'pol'], "Only supports 'gos' or 'pol'."
        self.seed = random_seed
        self.train_size = train_size
        self.val_size = val_size
        url = _get_dgl_url(self.file_urls[name])
        super(GASDataset, self).__init__(name=name,
                                         url=url,
                                         raw_dir=raw_dir)

    def process(self):
        """Process raw data into a heterogeneous graph with labels and masks."""
        data = sio.loadmat(os.path.join(self.raw_path, f'{self.name}_retweet_graph.mat'))

        adj = data['graph'].tocoo()
        num_edges = len(adj.row)
        half = num_edges // 2
        # The .mat adjacency stores each undirected edge twice; keep one copy
        # per direction and rebuild both directions explicitly below.
        row, col = adj.row[:half], adj.col[:half]

        graph = dgl.graph((np.concatenate((row, col)), np.concatenate((col, row))))
        news_labels = data['label'].squeeze()
        num_news = len(news_labels)

        node_feature = np.load(os.path.join(self.raw_path, f'{self.name}_node_feature.npy'))
        # Edge features are stored once per undirected edge; tile them so the
        # forward and backward directions share the same feature vector.
        edge_feature = np.load(os.path.join(self.raw_path, f'{self.name}_edge_feature.npy'))[:half]

        graph.ndata['feat'] = th.tensor(node_feature)
        graph.edata['feat'] = th.tensor(np.tile(edge_feature, (2, 1)))

        # An edge is labeled positive when either endpoint is a fake-news node.
        pos_news = news_labels.nonzero()[0]
        edge_labels = th.zeros(num_edges)
        edge_labels[graph.in_edges(pos_news, form='eid')] = 1
        edge_labels[graph.out_edges(pos_news, form='eid')] = 1
        graph.edata['label'] = edge_labels

        # Node type 0 ('v') = news nodes (the first num_news ids), 1 ('u') = users.
        # Edge type 0 ('forward') = first half, 1 ('backward') = reversed copies.
        ntypes = th.ones(graph.num_nodes(), dtype=th.int64)
        etypes = th.ones(graph.num_edges(), dtype=th.int64)
        ntypes[graph.nodes() < num_news] = 0
        etypes[:half] = 0
        graph.ndata['_TYPE'] = ntypes
        graph.edata['_TYPE'] = etypes

        hg = dgl.to_heterogeneous(graph, ['v', 'u'], ['forward', 'backward'])
        self._random_split(hg, self.seed, self.train_size, self.val_size)

        self.graph = hg

    def save(self):
        """Save the processed graph to `self.save_path`."""
        graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
        save_graphs(str(graph_path), self.graph)

    def has_cache(self):
        """Check whether processed data already exists in `self.save_path`."""
        graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
        return os.path.exists(graph_path)

    def load(self):
        """Load the processed graph from `self.save_path`."""
        graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')

        graph, _ = load_graphs(str(graph_path))
        self.graph = graph[0]

    @property
    def num_classes(self):
        """Number of classes for each graph, i.e. number of prediction tasks."""
        return 2

    def __getitem__(self, idx):
        r"""Get graph object.

        Parameters
        ----------
        idx : int
            Item index; must be 0, this dataset holds a single graph.

        Returns
        -------
        :class:`dgl.DGLGraph`
        """
        assert idx == 0, "This dataset has only one graph"
        return self.graph

    def __len__(self):
        r"""Number of data examples.

        Returns
        -------
        int
            Always 1 — the dataset consists of a single graph.
            (`len(self.graph)` would return the node count, not the
            number of examples.)
        """
        return 1

    def _random_split(self, graph, seed=717, train_size=0.7, val_size=0.1):
        """Split the 'forward' edges into train/val/test masks (mirrored on 'backward')."""

        assert 0 <= train_size + val_size <= 1, \
            "The sum of valid training set size and validation set size " \
            "must be between 0 and 1 (inclusive)."

        num_edges = graph.num_edges(etype='forward')
        index = np.arange(num_edges)

        # Deterministic shuffle so the split is reproducible for a given seed.
        index = np.random.RandomState(seed).permutation(index)
        train_idx = index[:int(train_size * num_edges)]
        val_idx = index[num_edges - int(val_size * num_edges):]
        test_idx = index[int(train_size * num_edges):num_edges - int(val_size * num_edges)]

        train_mask = np.zeros(num_edges, dtype=bool)
        val_mask = np.zeros(num_edges, dtype=bool)
        test_mask = np.zeros(num_edges, dtype=bool)

        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True
        # Forward/backward edge pairs share indices, so the same masks apply
        # to both directions.
        graph.edges['forward'].data['train_mask'] = th.tensor(train_mask)
        graph.edges['forward'].data['val_mask'] = th.tensor(val_mask)
        graph.edges['forward'].data['test_mask'] = th.tensor(test_mask)
        graph.edges['backward'].data['train_mask'] = th.tensor(train_mask)
        graph.edges['backward'].data['val_mask'] = th.tensor(val_mask)
        graph.edges['backward'].data['test_mask'] = th.tensor(test_mask)