Unverified Commit 480a4ae3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] pass ntype/etype into partition book when node/edge_split (#3828)

* [BugFix] pass ntype/etype into partition book when node/edge_split

* fix test failure

* fix test failue on mxnet

* fix test failure
parent 57d2f31f
...@@ -1511,7 +1511,7 @@ def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=Tru ...@@ -1511,7 +1511,7 @@ def node_split(nodes, partition_book=None, ntype='_N', rank=None, force_even=Tru
num_client_per_part, client_id_in_part) num_client_per_part, client_id_in_part)
else: else:
# Get all nodes that belong to the rank. # Get all nodes that belong to the rank.
local_nids = partition_book.partid2nids(partition_book.partid) local_nids = partition_book.partid2nids(partition_book.partid, ntype=ntype)
return _split_local(partition_book, rank, nodes, local_nids) return _split_local(partition_book, rank, nodes, local_nids)
def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=True, def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=True,
...@@ -1591,7 +1591,7 @@ def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=Tru ...@@ -1591,7 +1591,7 @@ def edge_split(edges, partition_book=None, etype='_E', rank=None, force_even=Tru
num_client_per_part, client_id_in_part) num_client_per_part, client_id_in_part)
else: else:
# Get all edges that belong to the rank. # Get all edges that belong to the rank.
local_eids = partition_book.partid2eids(partition_book.partid) local_eids = partition_book.partid2eids(partition_book.partid, etype=etype)
return _split_local(partition_book, rank, edges, local_eids) return _split_local(partition_book, rank, edges, local_eids)
rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse) rpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)
...@@ -19,6 +19,7 @@ import math ...@@ -19,6 +19,7 @@ import math
import unittest import unittest
import pickle import pickle
from utils import reset_envs, generate_ip_config from utils import reset_envs, generate_ip_config
import pytest
if os.name != 'nt': if os.name != 'nt':
import fcntl import fcntl
...@@ -642,14 +643,22 @@ def test_standalone_node_emb(): ...@@ -642,14 +643,22 @@ def test_standalone_node_emb():
dgl.distributed.exit_client() # this is needed since there's two test here in one process dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split(): @pytest.mark.parametrize("hetero", [True, False])
g = create_random_graph(10000) def test_split(hetero):
if hetero:
g = create_random_hetero()
ntype = 'n1'
etype = 'r1'
else:
g = create_random_graph(10000)
ntype = '_N'
etype = '_E'
num_parts = 4 num_parts = 4
num_hops = 2 num_hops = 2
partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis') partition_graph(g, 'dist_graph_test', num_parts, '/tmp/dist_graph', num_hops=num_hops, part_method='metis')
node_mask = np.random.randint(0, 100, size=g.number_of_nodes()) > 30 node_mask = np.random.randint(0, 100, size=g.number_of_nodes(ntype)) > 30
edge_mask = np.random.randint(0, 100, size=g.number_of_edges()) > 30 edge_mask = np.random.randint(0, 100, size=g.number_of_edges(etype)) > 30
selected_nodes = np.nonzero(node_mask)[0] selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0] selected_edges = np.nonzero(edge_mask)[0]
...@@ -666,32 +675,40 @@ def test_split(): ...@@ -666,32 +675,40 @@ def test_split():
part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i) part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition('/tmp/dist_graph/dist_graph_test.json', i)
local_nids = F.nonzero_1d(part_g.ndata['inner_node']) local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids) local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids)) if hetero:
nodes2 = node_split(node_mask, gpb, rank=i, force_even=False) ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
else:
local_nids = F.asnumpy(local_nids)
nodes1 = np.intersect1d(selected_nodes, local_nids)
nodes2 = node_split(node_mask, gpb, ntype=ntype, rank=i, force_even=False)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2))) assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
local_nids = F.asnumpy(local_nids) for n in F.asnumpy(nodes2):
for n in nodes1:
assert n in local_nids assert n in local_nids
set_roles(num_parts * 2) set_roles(num_parts * 2)
nodes3 = node_split(node_mask, gpb, rank=i * 2, force_even=False) nodes3 = node_split(node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False)
nodes4 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=False) nodes4 = node_split(node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False)
nodes5 = F.cat([nodes3, nodes4], 0) nodes5 = F.cat([nodes3, nodes4], 0)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5))) assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))
set_roles(num_parts) set_roles(num_parts)
local_eids = F.nonzero_1d(part_g.edata['inner_edge']) local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids) local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
edges1 = np.intersect1d(selected_edges, F.asnumpy(local_eids)) if hetero:
edges2 = edge_split(edge_mask, gpb, rank=i, force_even=False) etype_ids, eids = gpb.map_to_per_etype(local_eids)
local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]
else:
local_eids = F.asnumpy(local_eids)
edges1 = np.intersect1d(selected_edges, local_eids)
edges2 = edge_split(edge_mask, gpb, etype=etype, rank=i, force_even=False)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2))) assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
local_eids = F.asnumpy(local_eids) for e in F.asnumpy(edges2):
for e in edges1:
assert e in local_eids assert e in local_eids
set_roles(num_parts * 2) set_roles(num_parts * 2)
edges3 = edge_split(edge_mask, gpb, rank=i * 2, force_even=False) edges3 = edge_split(edge_mask, gpb, etype=etype, rank=i * 2, force_even=False)
edges4 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=False) edges4 = edge_split(edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False)
edges5 = F.cat([edges3, edges4], 0) edges5 = F.cat([edges3, edges4], 0)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5))) assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))
...@@ -770,7 +787,8 @@ if __name__ == '__main__': ...@@ -770,7 +787,8 @@ if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True) os.makedirs('/tmp/dist_graph', exist_ok=True)
test_dist_emb_server_client() test_dist_emb_server_client()
test_server_client() test_server_client()
test_split() test_split(True)
test_split(False)
test_split_even() test_split_even()
test_standalone() test_standalone()
test_standalone_node_emb() test_standalone_node_emb()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment