# preprocessing.py (7.01 KB)
1
2
import os

3
4
import graph_tool as gt
import graph_tool.topology as gt_topology
5
6
7
8
9
import networkx as nx
import numpy as np
import torch

from dgl.data.utils import load_graphs, save_graphs
# Hongzhi (Steve), Chen committed
10
11
from ogb.graphproppred import DglGraphPropPredDataset
from tqdm import tqdm
12

13
14

def to_undirected(edge_index):
    """Return every edge of ``edge_index`` followed by its reversed copy.

    Parameters
    ----------
    edge_index : torch.Tensor
        Tensor of shape ``(E, 2)`` whose rows are ``(source, target)`` pairs.
        An empty tensor of any shape is accepted and yields an empty list.

    Returns
    -------
    list
        ``2 * E`` two-element lists: first the ``E`` input edges in their
        original order, then the same edges with endpoints swapped.
    """
    # An edgeless input may arrive as a 1-D empty tensor (e.g.
    # ``torch.tensor([])``), which cannot be transposed; there is nothing
    # to mirror in that case.
    if edge_index.numel() == 0:
        return []

    row, col = edge_index.transpose(1, 0)
    # Append the reversed direction of every edge after the original ones.
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    return torch.stack([row, col], dim=1).tolist()

21
22

def induced_edge_automorphism_orbits(edge_list):
    """Compute the edge orbits of a substructure induced by its node automorphisms.

    An undirected graph-tool graph is built from ``edge_list`` (self-loops and
    parallel edges are removed), its automorphisms are enumerated by matching
    the graph against itself, every node receives an orbit label, and finally
    each directed edge is grouped by the unordered pair of orbit labels of its
    endpoints.

    Three summary lines (partition, number of orbits, automorphism count) are
    printed to standard output.

    Parameters
    ----------
    edge_list : iterable
        ``(source, target)`` pairs describing the substructure; passed to
        ``Graph.add_edge_list``.

    Returns
    -------
    tuple
        ``(graph, edge_orbit_partition, edge_orbit_membership, aut_count)``:

        * ``graph`` -- the graph-tool graph that was built.
        * ``edge_orbit_partition`` -- dict mapping an edge-orbit index to the
          list of directed edges (as tuples) belonging to it.
        * ``edge_orbit_membership`` -- dict mapping the position of a directed
          edge in ``to_undirected(graph.get_edges())`` to its edge-orbit index.
        * ``aut_count`` -- number of mappings returned by the automorphism
          search.
    """
    ##### node automorphism orbits #####
    graph = gt.Graph(directed=False)
    graph.add_edge_list(edge_list)
    gt.stats.remove_self_loops(graph)
    gt.stats.remove_parallel_edges(graph)

    # compute the node automorphism group
    aut_group = gt_topology.subgraph_isomorphism(
        graph, graph, induced=False, subgraph=True, generator=False
    )

    # every node starts out in its own orbit
    orbit_membership = {v: v for v in graph.get_vertices()}

    # whenever two nodes can be mapped via some automorphism, they are assigned the same orbit
    for aut in aut_group:
        for original, node in enumerate(aut):
            orbit_membership[node] = min(original, orbit_membership[node])

    # make orbit list contiguous (i.e. 0,1,2,...O)
    nodes = list(orbit_membership.keys())
    raw_orbits = list(orbit_membership.values())
    _, contiguous_orbit_membership = np.unique(raw_orbits, return_inverse=True)
    orbit_membership = {
        node: contiguous_orbit_membership[i] for i, node in enumerate(nodes)
    }

    aut_count = len(aut_group)

    ##### induced edge automorphism orbits (according to the node automorphism group) #####
    edge_orbit_partition = dict()
    edge_orbit_membership = dict()
    edge_orbits2inds = dict()

    edge_list = to_undirected(torch.tensor(graph.get_edges()))

    # infer edge automorphisms from the node automorphisms
    for i, edge in enumerate(edge_list):
        edge_orbit = frozenset(
            [orbit_membership[edge[0]], orbit_membership[edge[1]]]
        )
        # A previously unseen orbit pair gets the next free index; the default
        # is evaluated before insertion, so it equals the number of orbits
        # registered so far.
        ind_edge_orbit = edge_orbits2inds.setdefault(
            edge_orbit, len(edge_orbits2inds)
        )
        edge_orbit_partition.setdefault(ind_edge_orbit, []).append(tuple(edge))
        edge_orbit_membership[i] = ind_edge_orbit

    print(
        "Edge orbit partition of given substructure: {}".format(
            edge_orbit_partition
        )
    )
    print("Number of edge orbits: {}".format(len(edge_orbit_partition)))
    print("Graph (node) automorphism count: {}".format(aut_count))

    return graph, edge_orbit_partition, edge_orbit_membership, aut_count

98

99
100
def subgraph_isomorphism_edge_counts(edge_index, subgraph_dict):
    """Count, per edge and per edge orbit, the induced matches of a substructure.

    Parameters
    ----------
    edge_index : torch.Tensor
        Tensor of shape ``(2, E)`` holding the directed edges of the graph.
    subgraph_dict : dict
        Description of the substructure with the keys ``"subgraph"`` (a
        graph-tool graph), ``"orbit_partition"``, ``"orbit_membership"`` and
        ``"aut_count"``.

    Returns
    -------
    torch.Tensor
        Floating-point tensor of shape ``(E, number of edge orbits)``. Entry
        ``[e, o]`` is the number of matches that map a substructure edge of
        orbit ``o`` onto graph edge ``e``, divided by ``aut_count``.

    Raises
    ------
    KeyError
        If a matched directed edge is not present in ``edge_index``. This is
        presumably avoided by callers supplying both directions of each edge
        -- confirm against the dataset.
    """
    ##### edge structural identifiers #####
    edge_index = edge_index.transpose(1, 0).cpu().numpy()
    # position of every directed edge, used as the row index into ``counts``
    edge_dict = {tuple(edge): i for i, edge in enumerate(edge_index)}

    subgraph = subgraph_dict["subgraph"]
    orbit_membership = subgraph_dict["orbit_membership"]
    subgraph_edges = to_undirected(torch.tensor(subgraph.get_edges().tolist()))

    G_gt = gt.Graph(directed=False)
    G_gt.add_edge_list(list(edge_index))
    gt.stats.remove_self_loops(G_gt)
    gt.stats.remove_parallel_edges(G_gt)

    # compute all subgraph isomorphisms
    sub_iso = gt_topology.subgraph_isomorphism(
        subgraph,
        G_gt,
        induced=True,
        subgraph=True,
        generator=True,
    )

    counts = np.zeros(
        (edge_index.shape[0], len(subgraph_dict["orbit_partition"]))
    )

    for sub_iso_curr in sub_iso:
        mapping = sub_iso_curr.get_array()
        for i, edge in enumerate(subgraph_edges):
            # for every edge in the graph H, find the edge in the subgraph G_S to which it is mapped
            # (by finding where its endpoints are matched).
            # Then, increase the count of the matched edge w.r.t. the corresponding orbit
            # Repeat for the reverse edge (the one with the opposite direction)
            edge_orbit = orbit_membership[i]
            mapped_edge = (mapping[edge[0]], mapping[edge[1]])
            counts[edge_dict[mapped_edge], edge_orbit] += 1

    # normalise by the substructure's automorphism count
    counts = counts / subgraph_dict["aut_count"]

    return torch.tensor(counts)

147

148
149
150
151
def prepare_dataset(name):
    """Load the preprocessed dataset from disk, generating and caching it if absent.

    The cache lives at ``./dataset/<name>/processed/cycle_graph_induced_<k>.bin``.
    The ``processed`` directory is created when it does not exist.

    Parameters
    ----------
    name : str
        Dataset name; used both as the directory name and as the name passed
        to ``generate_dataset``.

    Returns
    -------
    tuple
        ``(g_list, split_idx)`` as returned by ``load_graphs`` or
        ``generate_dataset``.
    """
    # maximum size of cycle graph; only used for the cache file name here, so
    # it must stay in sync with the maximum cycle size used by generate_dataset
    k = 8

    path = os.path.join("./", "dataset", name)
    data_folder = os.path.join(path, "processed")
    os.makedirs(data_folder, exist_ok=True)

    data_file = os.path.join(
        data_folder, "cycle_graph_induced_{}.bin".format(k)
    )

    # try to load
    if os.path.exists(data_file):  # load
        print("Loading dataset from {}".format(data_file))
        g_list, split_idx = load_graphs(data_file)
    else:  # generate
        g_list, split_idx = generate_dataset(path, name)
        print("Saving dataset to {}".format(data_file))
        save_graphs(data_file, g_list, split_idx)

    return g_list, split_idx

171

172
173
174
175
176
177
178
179
180
181
def generate_dataset(path, name, k=8):
    """Build the dataset with cycle-count edge features from scratch.

    For every cycle graph with 3 up to ``k`` nodes the edge orbits and the
    automorphism count are computed; every graph of the OGB dataset is then
    annotated through ``_prepare``.

    Parameters
    ----------
    path : str
        Root directory handed to ``DglGraphPropPredDataset``.
    name : str
        Name of the OGB graph property prediction dataset.
    k : int, optional
        Size of the largest cycle graph to use as a substructure. Defaults to
        8, the value that was previously hard-coded.

    Returns
    -------
    tuple
        ``(graphs_dgl, split_idx)``: the list of processed graphs and the
        dataset's split dictionary extended with a ``"label"`` entry holding
        the stacked labels.
    """
    ### compute the orbits of each substructure in the list, as well as the node automorphism count
    subgraph_dicts = []
    for cycle_size in range(3, k + 1):
        edge_list = list(nx.cycle_graph(cycle_size).edges)
        (
            subgraph,
            orbit_partition,
            orbit_membership,
            aut_count,
        ) = induced_edge_automorphism_orbits(edge_list=edge_list)
        subgraph_dicts.append(
            {
                "subgraph": subgraph,
                "orbit_partition": orbit_partition,
                "orbit_membership": orbit_membership,
                "aut_count": aut_count,
            }
        )

    ### load and preprocess dataset
    dataset = DglGraphPropPredDataset(name=name, root=path)
    split_idx = dataset.get_idx_split()

    # computation of subgraph isomorphisms & creation of data structure
    graphs_dgl = []
    labels = []
    # Wrapping the dataset itself (rather than ``enumerate(dataset)``) lets
    # tqdm pick up the total from ``len(dataset)``.
    for g, label in tqdm(dataset):
        graphs_dgl.append(_prepare(g, subgraph_dicts))
        labels.append(label)

    split_idx["label"] = torch.stack(labels)

    return graphs_dgl, split_idx
213
214


215
216
def _prepare(g, subgraph_dicts):
    edge_index = torch.stack(g.edges())
217

218
219
220
    identifiers = None
    for subgraph_dict in subgraph_dicts:
        counts = subgraph_isomorphism_edge_counts(edge_index, subgraph_dict)
221
222
223
224
225
226
227
        identifiers = (
            counts
            if identifiers is None
            else torch.cat((identifiers, counts), 1)
        )

    g.edata["subgraph_counts"] = identifiers.long()
228
229
230

    return g

231
232
233

if __name__ == "__main__":
    # Script entry point: load the cached ogbg-molpcba preprocessing result,
    # or generate and cache it when it is not on disk yet.
    prepare_dataset(name="ogbg-molpcba")