# preprocessing.py
# Builds induced-cycle edge-count features for OGB graph-property datasets.
1
2
import os

3
4
import graph_tool as gt
import graph_tool.topology as gt_topology
5
6
7
8
import networkx as nx
import numpy as np
import torch
from ogb.graphproppred import DglGraphPropPredDataset
9
from tqdm import tqdm
10
11
12

from dgl.data.utils import load_graphs, save_graphs

13
14

def to_undirected(edge_index):
    """Return every edge in both orientations as a list of ``[src, dst]`` pairs.

    ``edge_index`` is read as an (E, 2) tensor: it is transposed and unpacked
    into exactly two rows (sources and targets). The result lists the E
    original pairs first, followed by the same E pairs reversed.
    """
    sources, targets = edge_index.transpose(1, 0)
    forward = torch.stack((sources, targets), dim=1)
    backward = torch.stack((targets, sources), dim=1)
    return torch.cat((forward, backward), dim=0).tolist()

21
22

def induced_edge_automorphism_orbits(edge_list):
    """Compute node-automorphism orbits of a small graph and the edge orbits they induce.

    The graph built from ``edge_list`` is undirected, with self-loops and
    parallel edges removed. All of its automorphisms are enumerated with
    ``subgraph_isomorphism(graph, graph, ...)``. Two nodes share an orbit
    whenever some automorphism maps one onto the other. Every edge (in both
    orientations) is then placed in an edge orbit identified by the unordered
    pair of its endpoints' node orbits.

    NOTE(review): grouping edges by their endpoint-orbit pair can be coarser
    than the true edge orbits for general graphs; it is exact for the cycle
    graphs built in this file.

    Args:
        edge_list: (u, v) pairs accepted by graph_tool's ``Graph.add_edge_list``.

    Returns:
        A 4-tuple ``(graph, edge_orbit_partition, edge_orbit_membership, aut_count)``:
        the simplified graph_tool graph; a dict orbit id -> list of ``(u, v)``
        tuples; a dict position in the undirected edge list -> orbit id; and the
        number of automorphisms found.

    Side effects:
        Prints the edge orbit partition, the number of edge orbits and the
        automorphism count.
    """
    # --- node automorphism orbits ---
    graph = gt.Graph(directed=False)
    graph.add_edge_list(edge_list)
    gt.stats.remove_self_loops(graph)
    gt.stats.remove_parallel_edges(graph)

    automorphisms = gt_topology.subgraph_isomorphism(
        graph, graph, induced=False, subgraph=True, generator=False
    )
    aut_count = len(automorphisms)

    # Each node starts as its own representative. Because the automorphisms
    # form a group, taking the minimum preimage over all of them yields the
    # smallest node of the orbit.
    representative = {vertex: vertex for vertex in graph.get_vertices()}
    for automorphism in automorphisms:
        for source, target in enumerate(automorphism):
            representative[target] = min(source, representative[target])

    # Relabel representatives to contiguous orbit ids 0, 1, ..., O.
    nodes = list(representative)
    _, relabelled = np.unique(list(representative.values()), return_inverse=True)
    node_orbit = dict(zip(nodes, relabelled))

    # --- edge orbits induced by the node automorphism group ---
    edge_orbit_partition = {}
    edge_orbit_membership = {}
    orbit_ids = {}
    undirected_edges = to_undirected(torch.tensor(graph.get_edges()))
    for position, (u, v) in enumerate(undirected_edges):
        orbit_key = frozenset([node_orbit[u], node_orbit[v]])
        # New orbit keys receive the next free id, in first-seen order.
        orbit_id = orbit_ids.setdefault(orbit_key, len(orbit_ids))
        edge_orbit_partition.setdefault(orbit_id, []).append((u, v))
        edge_orbit_membership[position] = orbit_id

    print(
        "Edge orbit partition of given substructure: {}".format(
            edge_orbit_partition
        )
    )
    print("Number of edge orbits: {}".format(len(edge_orbit_partition)))
    print("Graph (node) automorphism count: {}".format(aut_count))

    return graph, edge_orbit_partition, edge_orbit_membership, aut_count

99

100
def subgraph_isomorphism_edge_counts(edge_index, subgraph_dict):
    """Count how often each graph edge appears in each edge orbit of a substructure.

    Every induced occurrence of ``subgraph_dict["subgraph"]`` in the graph is
    enumerated. For each occurrence, every substructure edge (both orientations)
    is mapped onto the graph edge it lands on, and that graph edge's counter for
    the corresponding edge orbit is incremented. The totals are divided by
    ``subgraph_dict["aut_count"]``, because each occurrence is enumerated once per
    automorphism of the substructure.

    Args:
        edge_index: edge tensor indexed as (2, E) here (it is transposed to
            (E, 2)); row i of the result corresponds to column i of this tensor.
        subgraph_dict: dict with keys ``"subgraph"`` (graph_tool graph),
            ``"orbit_partition"``, ``"orbit_membership"`` (substructure edge
            position -> orbit id) and ``"aut_count"``.

    Returns:
        float64 tensor of shape (E, number of edge orbits).

    Raises:
        KeyError: if a matched edge's orientation is missing from ``edge_index``
            (presumably the input is expected to list both orientations —
            TODO confirm with callers).
    """
    edge_index = edge_index.transpose(1, 0).cpu().numpy()
    num_orbits = len(subgraph_dict["orbit_partition"])
    counts = np.zeros((edge_index.shape[0], num_orbits))

    # A graph without edges cannot contain the substructure; skip graph_tool.
    if edge_index.shape[0] == 0:
        return torch.tensor(counts)

    # Map each (src, dst) pair to every row carrying it, so that parallel edges
    # all receive the counts instead of only the last duplicate.
    rows_by_edge = {}
    for row, edge in enumerate(edge_index):
        rows_by_edge.setdefault(tuple(edge), []).append(row)

    subgraph = subgraph_dict["subgraph"]
    orbit_membership = subgraph_dict["orbit_membership"]
    # Both orientations of each substructure edge, paired with the orbit id of
    # its position in the undirected edge list (looked up once, not per match).
    subgraph_edges = [
        (u, v, orbit_membership[position])
        for position, (u, v) in enumerate(
            to_undirected(torch.tensor(subgraph.get_edges().tolist()))
        )
    ]

    # Simple undirected target graph: the input lists both orientations, so
    # parallel edges (and any self-loops) are dropped before matching.
    target = gt.Graph(directed=False)
    target.add_edge_list(list(edge_index))
    gt.stats.remove_self_loops(target)
    gt.stats.remove_parallel_edges(target)

    matches = gt_topology.subgraph_isomorphism(
        subgraph,
        target,
        induced=True,
        subgraph=True,
        generator=True,
    )
    for match in matches:
        # mapping[v] is the target vertex matched to substructure vertex v.
        mapping = match.get_array()
        for u, v, orbit in subgraph_edges:
            counts[rows_by_edge[(mapping[u], mapping[v])], orbit] += 1

    counts = counts / subgraph_dict["aut_count"]
    return torch.tensor(counts)

150

151
def prepare_dataset(name):
    """Return ``(graphs, split_idx)`` for dataset ``name``, caching to disk.

    The cache lives at ``./dataset/<name>/processed/cycle_graph_induced_8.bin``.
    If the file exists it is loaded with DGL's ``load_graphs``. Otherwise the
    data is built with ``generate_dataset`` and written with ``save_graphs``.
    """
    max_cycle_size = 8  # only used to name the cache file here

    dataset_dir = os.path.join("./", "dataset", name)
    cache_dir = os.path.join(dataset_dir, "processed")
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(
        cache_dir, "cycle_graph_induced_{}.bin".format(max_cycle_size)
    )

    if not os.path.exists(cache_file):
        g_list, split_idx = generate_dataset(dataset_dir, name)
        print("Saving dataset to {}".format(cache_file))
        save_graphs(cache_file, g_list, split_idx)
        return g_list, split_idx

    print("Loading dataset from {}".format(cache_file))
    g_list, split_idx = load_graphs(cache_file)
    return g_list, split_idx

175

176
177
178
179
180
181
182
183
184
185
186
def generate_dataset(path, name, min_cycle_len=3, max_cycle_len=8):
    """Build DGL graphs annotated with induced-cycle edge counts.

    For every cycle length in ``[min_cycle_len, max_cycle_len]``, the cycle's
    edge orbits and automorphism count are computed. Every graph of the OGB
    dataset ``name`` (stored under ``path``) then gets per-edge counts attached
    by ``_prepare``.

    Args:
        path: root directory passed to ``DglGraphPropPredDataset``.
        name: OGB dataset name, e.g. ``"ogbg-molpcba"``.
        min_cycle_len: smallest cycle length to count (defaults to 3).
        max_cycle_len: largest cycle length to count (defaults to 8). Callers
            that cache the output should key the cache on these bounds.

    Returns:
        ``(graphs, split_idx)``: the list of annotated DGL graphs, and the
        dataset's split dict with an extra ``"label"`` entry holding the
        stacked labels.

    Raises:
        ValueError: if ``min_cycle_len < 3`` (not a simple cycle) or
            ``max_cycle_len < min_cycle_len`` (no substructures to count).
    """
    if min_cycle_len < 3:
        raise ValueError("min_cycle_len must be at least 3")
    if max_cycle_len < min_cycle_len:
        raise ValueError("max_cycle_len must be >= min_cycle_len")

    ### compute the orbits of each substructure, plus its node automorphism count
    subgraph_dicts = []
    for cycle_len in range(min_cycle_len, max_cycle_len + 1):
        edge_list = list(nx.cycle_graph(cycle_len).edges)
        (
            subgraph,
            orbit_partition,
            orbit_membership,
            aut_count,
        ) = induced_edge_automorphism_orbits(edge_list=edge_list)
        subgraph_dicts.append(
            {
                "subgraph": subgraph,
                "orbit_partition": orbit_partition,
                "orbit_membership": orbit_membership,
                "aut_count": aut_count,
            }
        )

    ### load and preprocess dataset
    dataset = DglGraphPropPredDataset(name=name, root=path)
    split_idx = dataset.get_idx_split()

    # Computation of subgraph isomorphisms and creation of the data structure.
    # Passing the dataset directly (not via enumerate) lets tqdm show progress
    # against the known total.
    graphs_dgl = []
    labels = []
    for g, label in tqdm(dataset, total=len(dataset)):
        graphs_dgl.append(_prepare(g, subgraph_dicts))
        labels.append(label)

    split_idx["label"] = torch.stack(labels)

    return graphs_dgl, split_idx
218
219


220
221
222
def _prepare(g, subgraph_dicts):

    edge_index = torch.stack(g.edges())
223

224
225
226
    identifiers = None
    for subgraph_dict in subgraph_dicts:
        counts = subgraph_isomorphism_edge_counts(edge_index, subgraph_dict)
227
228
229
230
231
232
233
        identifiers = (
            counts
            if identifiers is None
            else torch.cat((identifiers, counts), 1)
        )

    g.edata["subgraph_counts"] = identifiers.long()
234
235
236

    return g

237
238
239

if __name__ == "__main__":
    # Build (or load from the on-disk cache) the cycle-count features for molpcba.
    DATASET_NAME = "ogbg-molpcba"
    prepare_dataset(DATASET_NAME)