Unverified Commit 480a4ae3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] pass ntype/etype into partition book when node/edge_split (#3828)

* [BugFix] pass ntype/etype into partition book when node/edge_split

* fix test failure

* fix test failure on mxnet

* fix test failure
parent 57d2f31f
......@@ -1511,7 +1511,7 @@ def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=Tru
num_client_per_part, client_id_in_part)
else:
# Get all nodes that belong to the rank.
local_nids = partition_book.partid2nids(partition_book.partid)
local_nids = partition_book.partid2nids(partition_book.partid, ntype=ntype)
return _split_local(partition_book, rank, nodes, local_nids)
def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=True,
......@@ -1591,7 +1591,7 @@ def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=Tru
num_client_per_part, client_id_in_part)
else:
# Get all edges that belong to the rank.
local_eids = partition_book.partid2eids(partition_book.partid)
local_eids = partition_book.partid2eids(partition_book.partid, etype=etype)
return _split_local(partition_book, rank, edges, local_eids)
rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
......@@ -19,6 +19,7 @@ import math
import unittest
import pickle
from utils import reset_envs, generate_ip_config
import pytest
if os.name != 'nt':
import fcntl
......@@ -642,14 +643,22 @@ def test_standalone_node_emb():
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split():
@pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero):
if hetero:
g = create_random_hetero()
ntype = 'n1'
etype = 'r1'
else:
g = create_random_graph(10000)
ntype = '_N'
etype = '_E'
num_parts = 4
num_hops = 2
partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis')
node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30
edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30
node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0]
......@@ -666,32 +675,40 @@ def test_split():
part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids))
nodes2 = node_split(node_mask, gpb, rank=i, force_even=False)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
if hetero:
ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
else:
local_nids = F.asnumpy(local_nids)
for n in nodes1:
nodes1 = np.intersect1d(selected_nodes, local_nids)
nodes2 = node_split(node_mask, gpb, ntype=ntype, rank=i, force_even=False)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
for n in F.asnumpy(nodes2):
assert n in local_nids
set_roles(num_parts * 2)
nodes3 = node_split(node_mask, gpb, rank=i * 2, force_even=False)
nodes4 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=False)
nodes3 = node_split(node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False)
nodes4 = node_split(node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False)
nodes5 = F.cat([nodes3, nodes4], 0)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))
set_roles(num_parts)
local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
edges1 = np.intersect1d(selected_edges, F.asnumpy(local_eids))
edges2 = edge_split(edge_mask, gpb, rank=i, force_even=False)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
if hetero:
etype_ids, eids = gpb.map_to_per_etype(local_eids)
local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
else:
local_eids = F.asnumpy(local_eids)
for e in edges1:
edges1 = np.intersect1d(selected_edges, local_eids)
edges2 = edge_split(edge_mask, gpb, etype=etype, rank=i, force_even=False)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
for e in F.asnumpy(edges2):
assert e in local_eids
set_roles(num_parts * 2)
edges3 = edge_split(edge_mask, gpb, rank=i * 2, force_even=False)
edges4 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=False)
edges3 = edge_split(edge_mask, gpb, etype=etype, rank=i * 2, force_even=False)
edges4 = edge_split(edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False)
edges5 = F.cat([edges3, edges4], 0)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))
......@@ -770,7 +787,8 @@ if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True)
test_dist_emb_server_client()
test_server_client()
test_split()
test_split(True)
test_split(False)
test_split_even()
test_standalone()
test_standalone_node_emb()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment