test_partition.py 5.62 KB
Newer Older
1
2
import dgl
import sys
3
import os
4
5
6
7
8
import numpy as np
from scipy import sparse as spsp
from numpy.testing import assert_array_equal
from dgl.graph_index import create_graph_index
from dgl.distributed import partition_graph, load_partition
9
from dgl import function as fn
10
11
12
import backend as F
import unittest
import pickle
13
import random
14
15

def create_random_graph(n):
    """Build a read-only random DGLGraph with ``n`` nodes.

    A sparse ``n x n`` matrix is drawn with density 0.001 and a fixed
    ``random_state`` of 100, turned into a 0/1 int64 matrix marking its
    non-zero entries, and wrapped as a read-only graph index.
    """
    random_matrix = spsp.random(n, n, density=0.001, format='coo', random_state=100)
    nonzero_mask = random_matrix != 0
    arr = nonzero_mask.astype(np.int64)
    graph_index = create_graph_index(arr, readonly=True)
    return dgl.DGLGraph(graph_index)

20
def check_partition(g, part_method, reshuffle):
    """Partition ``g`` into four parts and validate each partition loaded back.

    The graph is given node labels plus random node/edge features, and two
    reference aggregations ('h' from source-node features, 'eh' from edge
    features) are computed on the full graph.  The graph is then written to
    '/tmp/partition' with ``partition_graph`` and every partition is read
    back with ``load_partition`` and checked against the full graph.

    Parameters
    ----------
    g : graph object
        Graph to partition.  It is modified in place: 'labels', 'feats',
        'h' and 'eh' node data and 'feats' edge data are (re)assigned.
    part_method : str
        Forwarded to ``partition_graph`` as ``part_method``.
    reshuffle : bool
        Forwarded to ``partition_graph`` as ``reshuffle``.  When true, the
        'orig_id' node/edge data of each partition is used to look values up
        in the original graph, and the ID-to-partition maps are also checked.

    Raises
    ------
    AssertionError
        If any of the consistency checks fails.
    """
    g.ndata['labels'] = F.arange(0, g.number_of_nodes())
    g.ndata['feats'] = F.tensor(np.random.randn(g.number_of_nodes(), 10))
    g.edata['feats'] = F.tensor(np.random.randn(g.number_of_edges(), 10))
    # Reference results on the full graph; each partition is compared to these.
    g.update_all(fn.copy_src('feats', 'msg'), fn.sum('msg', 'h'))
    g.update_all(fn.copy_edge('feats', 'msg'), fn.sum('msg', 'eh'))
    num_parts = 4
    num_hops = 2

    partition_graph(g, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
                    part_method=part_method, reshuffle=reshuffle)
    part_sizes = []
    for i in range(num_parts):
        part_g, node_feats, edge_feats, gpb, _ = load_partition('/tmp/partition/test.json', i)

        # Check the metadata
        assert gpb._num_nodes() == g.number_of_nodes()
        assert gpb._num_edges() == g.number_of_edges()

        assert gpb.num_partitions() == num_parts
        gpb_meta = gpb.metadata()
        assert len(gpb_meta) == num_parts
        assert len(gpb.partid2nids(i)) == gpb_meta[i]['num_nodes']
        assert len(gpb.partid2eids(i)) == gpb_meta[i]['num_edges']
        part_sizes.append((gpb_meta[i]['num_nodes'], gpb_meta[i]['num_edges']))

        # IDs stored on the partition for the nodes/edges flagged as inner.
        # Computed once and reused by the local-ID and map checks below.
        local_nodes = F.boolean_mask(part_g.ndata[dgl.NID], part_g.ndata['inner_node'])
        local_edges = F.boolean_mask(part_g.edata[dgl.EID], part_g.edata['inner_edge'])
        # Positions of the inner nodes inside the partition graph itself.
        inner_node_idx = F.nonzero_1d(part_g.ndata['inner_node'])

        # Inner nodes/edges must map to local IDs 0..k-1, in order.
        local_nid = gpb.nid2localnid(local_nodes, i)
        assert F.dtype(local_nid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))
        local_eid = gpb.eid2localeid(local_edges, i)
        assert F.dtype(local_eid) in (F.int64, F.int32)
        assert np.all(F.asnumpy(local_eid) == np.arange(0, len(local_eid)))

        # Check the node map.
        local_nodes1 = gpb.partid2nids(i)
        assert F.dtype(local_nodes1) in (F.int32, F.int64)
        assert np.all(np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(local_nodes1)))

        # Check the edge map.
        local_edges1 = gpb.partid2eids(i)
        assert F.dtype(local_edges1) in (F.int32, F.int64)
        assert np.all(np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1)))

        if reshuffle:
            part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], part_g.ndata['orig_id'])
            part_g.edata['feats'] = F.gather_row(g.edata['feats'], part_g.edata['orig_id'])
            # when we read node data from the original global graph, we should use orig_id.
            local_nodes = F.boolean_mask(part_g.ndata['orig_id'], part_g.ndata['inner_node'])
            local_edges = F.boolean_mask(part_g.edata['orig_id'], part_g.edata['inner_edge'])
        else:
            part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], part_g.ndata[dgl.NID])
            # Edge data is keyed by dgl.EID, matching the other edge lookups in
            # this function (this line previously used dgl.NID).
            # NOTE(review): if dgl.NID and dgl.EID name the same key, this is
            # behavior-neutral -- confirm.
            part_g.edata['feats'] = F.gather_row(g.edata['feats'], part_g.edata[dgl.EID])

        # Re-run both aggregations on the partition; results on the inner
        # nodes must match the full-graph results.
        part_g.update_all(fn.copy_src('feats', 'msg'), fn.sum('msg', 'h'))
        part_g.update_all(fn.copy_edge('feats', 'msg'), fn.sum('msg', 'eh'))
        assert F.allclose(F.gather_row(g.ndata['h'], local_nodes),
                          F.gather_row(part_g.ndata['h'], inner_node_idx))
        assert F.allclose(F.gather_row(g.ndata['eh'], local_nodes),
                          F.gather_row(part_g.ndata['eh'], inner_node_idx))

        # Features returned by load_partition must equal the corresponding
        # rows of the original graph's data.
        for name in ['labels', 'feats']:
            assert name in node_feats
            assert node_feats[name].shape[0] == len(local_nodes)
            assert np.all(F.asnumpy(g.ndata[name])[F.asnumpy(local_nodes)] == F.asnumpy(node_feats[name]))
        for name in ['feats']:
            assert name in edge_feats
            assert edge_feats[name].shape[0] == len(local_edges)
            assert np.all(F.asnumpy(g.edata[name])[F.asnumpy(local_edges)] == F.asnumpy(edge_feats[name]))

    if reshuffle:
        # Expected maps: the first part_sizes[0] IDs belong to partition 0,
        # the next part_sizes[1] IDs to partition 1, and so on.
        node_map = np.concatenate(
            [np.full(num_nodes, i) for i, (num_nodes, _) in enumerate(part_sizes)])
        edge_map = np.concatenate(
            [np.full(num_edges, i) for i, (_, num_edges) in enumerate(part_sizes)])
        # gpb is the partition book returned for the last loaded partition.
        nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))
        assert F.dtype(nid2pid) in (F.int32, F.int64)
        assert np.all(F.asnumpy(nid2pid) == node_map)
        eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))
        assert F.dtype(eid2pid) in (F.int32, F.int64)
        assert np.all(F.asnumpy(eid2pid) == edge_map)
Da Zheng's avatar
Da Zheng committed
105
106

def test_partition():
    """Run check_partition on a 10000-node random graph.

    Covers both partition methods ('metis', 'random'), each with and
    without reshuffling, reusing the same graph object for all four runs.
    """
    graph = create_random_graph(10000)
    for method in ('metis', 'random'):
        for reshuffle in (True, False):
            check_partition(graph, method, reshuffle)

def test_hetero_partition():
    """Run check_partition on a random graph converted with dgl.as_heterograph.

    Covers both partition methods ('metis', 'random'), each with and
    without reshuffling, reusing the same converted graph for all four runs.
    """
    hetero_graph = dgl.as_heterograph(create_random_graph(10000))
    for method in ('metis', 'random'):
        for reshuffle in (True, False):
            check_partition(hetero_graph, method, reshuffle)
Da Zheng's avatar
Da Zheng committed
120

121
122

if __name__ == '__main__':
    # Create the output directory the partition tests write to, then run
    # the tests in order when this file is executed as a script.
    os.makedirs('/tmp/partition', exist_ok=True)
    for run_test in (test_partition, test_hetero_partition):
        run_test()