"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "8d710c82b6e1b9cd07254e23555b81c8a45ac014"
Unverified commit 4b96addc, authored by Da Zheng and committed by GitHub
Browse files

[BUGFIX] Fix bugs in range partitioning. (#1703)

* Fix bugs in partitioning.

* add comment.

* fix lint.

* fix test.

* add more tests.

* simplify the fix

* fix.
parent b742c559
......@@ -83,6 +83,7 @@ import numpy as np
from .. import backend as F
from ..base import NID, EID
from ..random import choice as random_choice
from ..data.utils import load_graphs, save_graphs, load_tensors, save_tensors
from ..transform import metis_partition_assignment, partition_graph_with_halo
from .graph_partition_book import GraphPartitionBook, RangePartitionBook
......@@ -242,7 +243,7 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
node_parts = metis_partition_assignment(g, num_parts)
client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
elif part_method == 'random':
node_parts = dgl.random.choice(num_parts, g.number_of_nodes())
node_parts = random_choice(num_parts, g.number_of_nodes())
client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
else:
raise Exception('Unknown partitioning method: ' + part_method)
......
......@@ -619,7 +619,8 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
if reshuffle:
node_part = node_part.tousertensor()
sorted_part, new2old_map = F.sort_1d(node_part)
new_node_ids = F.gather_row(F.arange(0, g.number_of_nodes()), new2old_map)
new_node_ids = np.zeros((g.number_of_nodes(),), dtype=np.int64)
new_node_ids[F.asnumpy(new2old_map)] = np.arange(0, g.number_of_nodes())
g = reorder_nodes(g, new_node_ids)
node_part = utils.toindex(sorted_part)
# We reassign edges in in-CSR. In this way, after partitioning, we can ensure
......
......@@ -242,6 +242,12 @@ def test_partition_with_halo():
block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2))
assert np.all(np.sort(block_eids1) == np.sort(block_eids2))
subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2, reshuffle=True)
for part_id, subg in subgs.items():
node_ids = np.nonzero(node_part == part_id)[0]
lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0]
assert np.all(np.sort(F.asnumpy(subg.ndata['orig_id'])[lnode_ids]) == node_ids)
@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU")
def test_metis_partition():
# TODO(zhengda) Metis fails to partition a small graph.
......@@ -322,6 +328,13 @@ def test_reorder_nodes():
assert np.all(F.asnumpy(new_in_deg == new_in_deg1))
assert np.all(F.asnumpy(new_out_deg == new_out_deg1))
orig_ids = F.asnumpy(new_g.ndata['orig_id'])
for nid in range(g.number_of_nodes()):
neighs = F.asnumpy(g.successors(nid))
new_neighs1 = new_nids[neighs]
new_nid = new_nids[nid]
new_neighs2 = new_g.successors(new_nid)
assert np.all(np.sort(new_neighs1) == np.sort(F.asnumpy(new_neighs2)))
for nid in range(new_g.number_of_nodes()):
neighs = F.asnumpy(new_g.successors(nid))
old_neighs1 = orig_ids[neighs]
......
......@@ -17,7 +17,7 @@ def create_random_graph(n):
ig = create_graph_index(arr, readonly=True)
return dgl.DGLGraph(ig)
def check_partition(reshuffle):
def check_partition(part_method, reshuffle):
g = create_random_graph(10000)
g.ndata['labels'] = F.arange(0, g.number_of_nodes())
g.ndata['feats'] = F.tensor(np.random.randn(g.number_of_nodes(), 10))
......@@ -28,7 +28,7 @@ def check_partition(reshuffle):
num_hops = 2
partition_graph(g, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
part_method='metis', reshuffle=reshuffle)
part_method=part_method, reshuffle=reshuffle)
part_sizes = []
for i in range(num_parts):
part_g, node_feats, edge_feats, gpb = load_partition('/tmp/partition/test.json', i)
......@@ -105,8 +105,10 @@ def check_partition(reshuffle):
assert np.all(F.asnumpy(eid2pid) == edge_map)
def test_partition():
check_partition(True)
check_partition(False)
check_partition('metis', True)
check_partition('metis', False)
check_partition('random', True)
check_partition('random', False)
if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment