verify_mag_partitions.py 4.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import json
import numpy as np
import dgl
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset

# Read the partition metadata stored under outputs/. JSON text is UTF-8 by
# specification, so decode it explicitly instead of depending on the
# platform's locale encoding.
with open('outputs/mag.json', encoding='utf-8') as json_file:
    metadata = json.load(json_file)
num_parts = metadata['num_parts']

# Load OGB-MAG and rebuild it with an extra 'rev-'-prefixed relation for every
# relation, whose edges run in the opposite direction.
dataset = DglNodePropPredDataset(name='ogbn-mag')
hg_orig, labels = dataset[0]
subgs = {}
for src_type, rel_name, dst_type in hg_orig.canonical_etypes:
    src_ids, dst_ids = hg_orig.all_edges(etype=(src_type, rel_name, dst_type))
    # Forward relation first, then its reversed counterpart.
    subgs[(src_type, rel_name, dst_type)] = (src_ids, dst_ids)
    subgs[(dst_type, 'rev-' + rel_name, src_type)] = (dst_ids, src_ids)
hg = dgl.heterograph(subgs)
# Carry the paper features over to the rebuilt graph.
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat']

# Construct node data and edge data after reshuffling: collect the per-partition
# feature tensors under their keys, in partition order, then concatenate them.
node_feats = {}
edge_feats = {}
for partid in range(num_parts):
    part_node_feats = dgl.data.utils.load_tensors('outputs/part{}/node_feat.dgl'.format(partid))
    part_edge_feats = dgl.data.utils.load_tensors('outputs/part{}/edge_feat.dgl'.format(partid))
    for merged, part_feats in ((node_feats, part_node_feats),
                               (edge_feats, part_edge_feats)):
        for key, tensor in part_feats.items():
            merged.setdefault(key, []).append(tensor)
node_feats = {key: th.cat(chunks) for key, chunks in node_feats.items()}
edge_feats = {key: th.cat(chunks) for key, chunks in edge_feats.items()}

# Invert the name -> ID tables from the metadata into lists indexed by type ID.
ntype_map = metadata['ntypes']
ntypes = [None] * len(ntype_map)
for type_name, type_id in ntype_map.items():
    ntypes[type_id] = type_name

etype_map = metadata['etypes']
etypes = [None] * len(etype_map)
for type_name, type_id in etype_map.items():
    etypes[type_id] = type_name

# Replace each list of rows in the metadata maps with one stacked tensor (in
# place), then build the lookups that split a global ID into a type ID and a
# per-type ID.
node_map = metadata['node_map']
node_map.update({
    name: th.stack([th.tensor(row) for row in rows], dim=0)
    for name, rows in node_map.items()
})
nid_map = dgl.distributed.id_map.IdMap(node_map)

edge_map = metadata['edge_map']
edge_map.update({
    name: th.stack([th.tensor(row) for row in rows], dim=0)
    for name, rows in edge_map.items()
})
eid_map = dgl.distributed.id_map.IdMap(edge_map)

def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    A plain ``assert`` is removed when Python runs with ``-O``, which would
    turn this whole verification into a no-op, and it reports no context on
    failure. This helper always runs and says which check failed.
    """
    if not condition:
        raise AssertionError(message)


# Load the graph partition structure and verify each partition against the
# original graph.
for partid in range(num_parts):
    print('test part', partid)
    part_file = 'outputs/part{}/graph.dgl'.format(partid)
    subg = dgl.load_graphs(part_file)[0][0]
    # Translate the partition-local edge endpoints into the IDs stored in
    # 'orig_id'; those are what the original graph is queried with below.
    subg_src_id, subg_dst_id = subg.edges()
    subg_src_id = subg.ndata['orig_id'][subg_src_id]
    subg_dst_id = subg.ndata['orig_id'][subg_dst_id]
    subg_ntype = subg.ndata[dgl.NTYPE]
    subg_etype = subg.edata[dgl.ETYPE]

    for ntype_id in th.unique(subg_ntype):
        ntype = ntypes[ntype_id]
        idx = subg_ntype == ntype_id
        # This is global IDs after reshuffle.
        nid = subg.ndata[dgl.NID][idx]
        ntype_ids1, type_nid = nid_map(nid)
        orig_type_nid = subg.ndata['orig_id'][idx]
        # All nodes should have the same node type.
        _check(np.all(ntype_ids1.numpy() == int(ntype_id)),
               'part {}: ID map disagrees with node type {!r}'.format(partid, ntype))

        # Check node data: the reshuffled features must equal the features of
        # the original graph at the corresponding original IDs.
        for name in hg.nodes[ntype].data:
            local_data = node_feats[ntype + '/' + name][type_nid]
            local_data1 = hg.nodes[ntype].data[name][orig_type_nid]
            _check(np.all(local_data.numpy() == local_data1.numpy()),
                   'part {}: node data {!r} differs from the original graph'.format(
                       partid, ntype + '/' + name))

    for etype_id in th.unique(subg_etype):
        etype = etypes[etype_id]
        idx = subg_etype == etype_id
        # Every edge of this type must exist in the original graph ...
        exist = hg[etype].has_edges_between(subg_src_id[idx], subg_dst_id[idx])
        _check(np.all(exist.numpy()),
               'part {}: edges of type {!r} are missing from the original graph'.format(
                   partid, etype))
        # ... and its stored 'orig_id' must be the ID the original graph
        # reports for the same endpoints.
        eid = hg[etype].edge_ids(subg_src_id[idx], subg_dst_id[idx])
        _check(np.all(eid.numpy() == subg.edata['orig_id'][idx].numpy()),
               "part {}: 'orig_id' of {!r} edges does not match the original graph".format(
                   partid, etype))

        # This is global IDs after reshuffle.
        eid = subg.edata[dgl.EID][idx]
        etype_ids1, type_eid = eid_map(eid)
        orig_type_eid = subg.edata['orig_id'][idx]
        # All edges should have the same edge type.
        _check(np.all(etype_ids1.numpy() == int(etype_id)),
               'part {}: ID map disagrees with edge type {!r}'.format(partid, etype))

        # Check edge data the same way as node data.
        for name in hg.edges[etype].data:
            local_data = edge_feats[etype + '/' + name][type_eid]
            local_data1 = hg.edges[etype].data[name][orig_type_eid]
            _check(np.all(local_data.numpy() == local_data1.numpy()),
                   'part {}: edge data {!r} differs from the original graph'.format(
                       partid, etype + '/' + name))