import os

import numpy as np
import scipy.io as sio
import torch as th

import dgl
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url, load_graphs, save_graphs


class GASDataset(DGLBuiltinDataset):
    file_urls = {"pol": "dataset/GASPOL.zip", "gos": "dataset/GASGOS.zip"}

    def __init__(
        self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1
    ):
        assert name in ["gos", "pol"], "Only supports 'gos' or 'pol'."
        self.seed = random_seed
        self.train_size = train_size
        self.val_size = val_size
        url = _get_dgl_url(self.file_urls[name])
        super(GASDataset, self).__init__(name=name, url=url, raw_dir=raw_dir)

    def process(self):
        """process raw data to graph, labels and masks"""
        data = sio.loadmat(
            os.path.join(self.raw_path, f"{self.name}_retweet_graph.mat")
        )

        # The stored adjacency matrix contains each retweet edge in both
        # directions, so keep only the first half of its COO entries here and
        # rebuild the reverse direction explicitly below.
        adj = data["graph"].tocoo()
        num_edges = len(adj.row)
        row, col = adj.row[: int(num_edges / 2)], adj.col[: int(num_edges / 2)]

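        # Build a bidirected homogeneous graph: the kept half of the edges
        # first, followed by their reverse copies.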
        graph = dgl.graph(
            (np.concatenate((row, col)), np.concatenate((col, row)))
        )
        news_labels = data["label"].squeeze()
        num_news = len(news_labels)

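        # Precomputed node/edge features are stored as .npy files; each edge
        # feature exists once per retweet and is tiled below so the forward and
        # reverse copies of an edge share the same feature.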
        node_feature = np.load(
            os.path.join(self.raw_path, f"{self.name}_node_feature.npy")
        )
        edge_feature = np.load(
            os.path.join(self.raw_path, f"{self.name}_edge_feature.npy")
        )[: int(num_edges / 2)]

        graph.ndata["feat"] = th.tensor(node_feature)
        graph.edata["feat"] = th.tensor(np.tile(edge_feature, (2, 1)))
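        # An edge is labeled positive when it touches a news node whose label
        # is nonzero.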
        pos_news = news_labels.nonzero()[0]

        edge_labels = th.zeros(num_edges)
        edge_labels[graph.in_edges(pos_news, form="eid")] = 1
        edge_labels[graph.out_edges(pos_news, form="eid")] = 1
        graph.edata["label"] = edge_labels

        # Type ids for the heterograph conversion below: the first num_news
        # nodes (the ones carrying news labels) become node type 0 ("v"), the
        # rest type 1 ("u"); the first half of the edges (the original
        # direction) become edge type 0 ("forward"), their reverse copies
        # type 1 ("backward").
        ntypes = th.ones(graph.num_nodes(), dtype=th.int64)
        etypes = th.ones(graph.num_edges(), dtype=th.int64)

        ntypes[graph.nodes() < num_news] = 0
        etypes[: int(num_edges / 2)] = 0

        graph.ndata["_TYPE"] = ntypes
        graph.edata["_TYPE"] = etypes

        hg = dgl.to_heterogeneous(graph, ["v", "u"], ["forward", "backward"])
        self._random_split(hg, self.seed, self.train_size, self.val_size)

        self.graph = hg

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
        save_graphs(str(graph_path), self.graph)

    def has_cache(self):
        """check whether there are processed data in `self.save_path`"""
        graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
        return os.path.exists(graph_path)

    def load(self):
        """load processed data from directory `self.save_path`"""
        graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")

        graph, _ = load_graphs(str(graph_path))
        self.graph = graph[0]

    @property
    def num_classes(self):
        """Number of classes for each graph, i.e. number of prediction tasks."""
        return 2

    def __getitem__(self, idx):
        r"""Get graph object

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        :class:`dgl.DGLGraph`
        """
        assert idx == 0, "This dataset has only one graph"
        return self.graph

    def __len__(self):
        r"""Number of data examples

        Returns
        -------
        int
        """
        return 1

    def _random_split(self, graph, seed=717, train_size=0.7, val_size=0.1):
        """split the dataset into training set, validation set and testing set"""

        assert 0 <= train_size + val_size <= 1, (
            "The sum of the training set size and validation set size "
            "must be between 0 and 1 (inclusive)."
        )

        num_edges = graph.num_edges(etype="forward")
        index = np.arange(num_edges)

        index = np.random.RandomState(seed).permutation(index)
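        # After the seeded shuffle, the first train_size fraction of edge ids
        # goes to training, the last val_size fraction to validation, and the
        # remainder in the middle to testing.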
        train_idx = index[: int(train_size * num_edges)]
        val_idx = index[num_edges - int(val_size * num_edges) :]
        test_idx = index[
            int(train_size * num_edges) : num_edges - int(val_size * num_edges)
        ]
        train_mask = np.zeros(num_edges, dtype=bool)
        val_mask = np.zeros(num_edges, dtype=bool)
        test_mask = np.zeros(num_edges, dtype=bool)
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True
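        # The forward and backward copies of an edge share the same split, so
        # the same masks are attached to both edge types.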
        graph.edges["forward"].data["train_mask"] = th.tensor(train_mask)
        graph.edges["forward"].data["val_mask"] = th.tensor(val_mask)
        graph.edges["forward"].data["test_mask"] = th.tensor(test_mask)
        graph.edges["backward"].data["train_mask"] = th.tensor(train_mask)
        graph.edges["backward"].data["val_mask"] = th.tensor(val_mask)
        graph.edges["backward"].data["test_mask"] = th.tensor(test_mask)