preprocessing.py 6.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from ogb.graphproppred import DglGraphPropPredDataset
import torch
import numpy as np
import networkx as nx
import graph_tool as gt
import graph_tool.topology as gt_topology
from tqdm import tqdm
import os
from dgl.data.utils import save_graphs, load_graphs

def to_undirected(edge_index):
    row, col = edge_index.transpose(1,0)
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    edge_index = torch.stack([row, col], dim=0)

    return edge_index.transpose(1,0).tolist()

def induced_edge_automorphism_orbits(edge_list):
        
    ##### node automorphism orbits ##### 
    graph = gt.Graph(directed=False)
    graph.add_edge_list(edge_list)
    gt.stats.remove_self_loops(graph)
    gt.stats.remove_parallel_edges(graph)

    # compute the node automorphism group
    aut_group = gt_topology.subgraph_isomorphism(graph, graph, induced=False, subgraph=True, generator=False)

    orbit_membership = {}
    for v in graph.get_vertices():
        orbit_membership[v] = v
    
    # whenever two nodes can be mapped via some automorphism, they are assigned the same orbit
    for aut in aut_group:
        for original, node in enumerate(aut):
            role = min(original, orbit_membership[node])
            orbit_membership[node] = role
        
    orbit_membership_list = [[],[]]
    for node, om_curr in orbit_membership.items():
        orbit_membership_list[0].append(node)
        orbit_membership_list[1].append(om_curr)

    # make orbit list contiguous (i.e. 0,1,2,...O)
    _, contiguous_orbit_membership = np.unique(orbit_membership_list[1], return_inverse = True)

    orbit_membership = {node: contiguous_orbit_membership[i] for i,node in enumerate(orbit_membership_list[0])}

    aut_count = len(aut_group)

    ##### induced edge automorphism orbits (according to the node automorphism group) #####
    edge_orbit_partition = dict()
    edge_orbit_membership = dict()
    edge_orbits2inds = dict()
    ind = 0
    
    edge_list = to_undirected(torch.tensor(graph.get_edges()))

    # infer edge automorphisms from the node automorphisms
    for i,edge in enumerate(edge_list):
        edge_orbit = frozenset([orbit_membership[edge[0]], orbit_membership[edge[1]]])
        if edge_orbit not in edge_orbits2inds:
            edge_orbits2inds[edge_orbit] = ind
            ind_edge_orbit = ind
            ind += 1
        else:
            ind_edge_orbit = edge_orbits2inds[edge_orbit]

        if ind_edge_orbit not in edge_orbit_partition:
            edge_orbit_partition[ind_edge_orbit] = [tuple(edge)]
        else:
            edge_orbit_partition[ind_edge_orbit] += [tuple(edge)] 

        edge_orbit_membership[i] = ind_edge_orbit

    print('Edge orbit partition of given substructure: {}'.format(edge_orbit_partition)) 
    print('Number of edge orbits: {}'.format(len(edge_orbit_partition)))
    print('Graph (node) automorphism count: {}'.format(aut_count))
    
    return graph, edge_orbit_partition, edge_orbit_membership, aut_count

def subgraph_isomorphism_edge_counts(edge_index, subgraph_dict):
    
    ##### edge structural identifiers #####
    
    edge_index = edge_index.transpose(1,0).cpu().numpy()
    edge_dict = {}
    for i, edge in enumerate(edge_index):         
        edge_dict[tuple(edge)] = i
        
    subgraph_edges = to_undirected(torch.tensor(subgraph_dict['subgraph'].get_edges().tolist()))

    G_gt = gt.Graph(directed=False)
    G_gt.add_edge_list(list(edge_index))
    gt.stats.remove_self_loops(G_gt)
    gt.stats.remove_parallel_edges(G_gt)  
       
    # compute all subgraph isomorphisms
    sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=True, subgraph=True, generator=True)
    
    counts = np.zeros((edge_index.shape[0], len(subgraph_dict['orbit_partition'])))
    
    for sub_iso_curr in sub_iso:
        mapping = sub_iso_curr.get_array()
        for i,edge in enumerate(subgraph_edges): 
            
            # for every edge in the graph H, find the edge in the subgraph G_S to which it is mapped
            # (by finding where its endpoints are matched). 
            # Then, increase the count of the matched edge w.r.t. the corresponding orbit
            # Repeat for the reverse edge (the one with the opposite direction)
            
            edge_orbit = subgraph_dict['orbit_membership'][i]
            mapped_edge = tuple([mapping[edge[0]], mapping[edge[1]]])
            counts[edge_dict[mapped_edge], edge_orbit] += 1
            
    counts = counts/subgraph_dict['aut_count']
    
    counts = torch.tensor(counts)
    
    return counts

def prepare_dataset(name):
    
    # maximum size of cycle graph
    k = 8

    path = os.path.join('./', 'dataset', name)
    data_folder = os.path.join(path, 'processed')
    os.makedirs(data_folder, exist_ok=True)
    
    data_file = os.path.join(data_folder, 'cycle_graph_induced_{}.bin'.format(k))
        
    # try to load
    if os.path.exists(data_file):  # load
        print("Loading dataset from {}".format(data_file))
        g_list, split_idx = load_graphs(data_file)
    else: # generate
        g_list, split_idx = generate_dataset(path, name)
        print("Saving dataset to {}".format(data_file))
        save_graphs(data_file, g_list, split_idx)

    return g_list, split_idx

def generate_dataset(path, name):

    ### compute the orbits of each substructure in the list, as well as the node automorphism count
    subgraph_dicts = []

    edge_lists = []
    for k in range(3, 8 + 1):
        graphs_nx = nx.cycle_graph(k)
        edge_lists.append(list(graphs_nx.edges))

    for edge_list in edge_lists:
        subgraph, orbit_partition, orbit_membership, aut_count = induced_edge_automorphism_orbits(edge_list=edge_list)
        subgraph_dicts.append({'subgraph':subgraph, 'orbit_partition': orbit_partition, 
                               'orbit_membership': orbit_membership, 'aut_count': aut_count})
        
    ### load and preprocess dataset
    dataset = DglGraphPropPredDataset(name=name, root=path)
    split_idx = dataset.get_idx_split()
        
    # computation of subgraph isomorphisms & creation of data structure
    graphs_dgl = list()
    split_idx["label"] = []
    for i, datapoint in tqdm(enumerate(dataset)):
        g, label = datapoint
        g = _prepare(g, subgraph_dicts)
        graphs_dgl.append(g)
        split_idx["label"].append(label)

    split_idx["label"] = torch.stack(split_idx["label"])

    return graphs_dgl, split_idx
        
def _prepare(g, subgraph_dicts):

    edge_index = torch.stack(g.edges())
    
    identifiers = None
    for subgraph_dict in subgraph_dicts:
        counts = subgraph_isomorphism_edge_counts(edge_index, subgraph_dict)
        identifiers = counts if identifiers is None else torch.cat((identifiers, counts),1) 

    g.edata['subgraph_counts'] = identifiers.long()
    
    return g

if __name__ == '__main__':
    prepare_dataset("ogbg-molpcba")