Unverified Commit e7046f1e authored by Da Zheng, committed by GitHub

[Distributed] Fix a bug in graph partition. (#2869)



* update distributed training doc.

* explain data split.

* fix message passing.

* id mapping.

* fix.

* test data reshuffling.

* fix a bug.

* fix test.

* Revert "fix."

This reverts commit 2d025e9e1a5c05c3da9b803a035a788ced59bd77.

* Revert "id mapping."

This reverts commit 2a6a93ceb81fbdff86e6e9e5a58e1ace1e9d9882.

* Revert "fix message passing."

This reverts commit ed8a86bf2b015e5e4f64ba160e81b207ad2a1d65.

* Revert "explain data split."

This reverts commit 4338ddf8a336014cf92d4cb9a1db02b9badc0e55.

* Revert "update distributed training doc."

This reverts commit dda1c35c44536934c19715534f01f832afda6ad2.

* add more tests.

* fix.

* fix.

* fix.
Co-authored-by: Zheng <dzzhen@3c22fba32af5.ant.amazon.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent 2ae891a7
@@ -7,6 +7,7 @@ from .heterograph import DGLHeteroGraph
 from . import backend as F
 from . import utils
 from .base import EID, NID, NTYPE, ETYPE
+from .subgraph import edge_subgraph
 
 __all__ = ["metis_partition", "metis_partition_assignment",
            "partition_graph_with_halo"]
@@ -173,8 +174,16 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
     # This creates a subgraph from subgraphs returned from the CAPI above.
     def create_subgraph(subg, induced_nodes, induced_edges):
         subg1 = DGLHeteroGraph(gidx=subg.graph, ntypes=['_N'], etypes=['_E'])
-        subg1.ndata[NID] = induced_nodes[0]
-        subg1.edata[EID] = induced_edges[0]
+        # If IDs are shuffled, we should also shuffle edges. This will help us
+        # collect edge data from the distributed graph after training.
+        if reshuffle:
+            sorted_edges, index = F.sort_1d(induced_edges[0])
+            subg1 = edge_subgraph(subg1, index, preserve_nodes=True)
+            subg1.ndata[NID] = induced_nodes[0]
+            subg1.edata[EID] = sorted_edges
+        else:
+            subg1.ndata[NID] = induced_nodes[0]
+            subg1.edata[EID] = induced_edges[0]
         return subg1
 
     for i, subg in enumerate(subgs):
...
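The core of the fix is the `reshuffle` branch above: after reshuffling, each partition's edge data must be keyed by ascending global edge ID, so `create_subgraph` sorts the induced edge IDs and relabels the subgraph's edges to match. A minimal numpy sketch of that reordering idea (illustrative values and names, not DGL's API):

```python
import numpy as np

# induced_edges: for each local edge in the partition, its shuffled
# global edge ID, in arbitrary order (toy values).
induced_edges = np.array([7, 2, 9, 4])

# Mirror F.sort_1d: sort the IDs and keep the permutation that sorts them.
index = np.argsort(induced_edges)
sorted_edges = induced_edges[index]

# Relabeling the local edges by `index` (edge_subgraph(..., preserve_nodes=True)
# in the patch) makes local edge i correspond to the i-th smallest global ID,
# so the partition's edge data forms one contiguous, ordered block.
assert np.all(sorted_edges == np.array([2, 4, 7, 9]))
assert np.all(np.diff(sorted_edges) > 0)
```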
@@ -39,7 +39,7 @@ def create_random_graph(n):
     return dgl.from_scipy(arr)
 
 def create_random_hetero():
-    num_nodes = {'n1': 10000, 'n2': 10010, 'n3': 10020}
+    num_nodes = {'n1': 1000, 'n2': 1010, 'n3': 1020}
     etypes = [('n1', 'r1', 'n2'),
               ('n1', 'r2', 'n3'),
               ('n2', 'r3', 'n3')]
@@ -120,24 +120,50 @@ def verify_hetero_graph(g, parts):
         assert len(uniq_ids) == g.number_of_edges(etype)
     # TODO(zhengda) this doesn't check 'part_id'
 
-def verify_graph_feats(g, part, node_feats):
+def verify_graph_feats(g, gpb, part, node_feats, edge_feats):
     for ntype in g.ntypes:
         ntype_id = g.get_ntype_id(ntype)
+        inner_node_mask = _get_inner_node_mask(part, ntype_id)
+        inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
+        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)
+        partid = gpb.nid2partid(inner_type_nids, ntype)
+        assert np.all(F.asnumpy(ntype_ids) == ntype_id)
+        assert np.all(F.asnumpy(partid) == gpb.partid)
+        orig_id = F.boolean_mask(part.ndata['orig_id'], inner_node_mask)
+        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)
         for name in g.nodes[ntype].data:
             if name in [dgl.NID, 'inner_node']:
                 continue
-            inner_node_mask = _get_inner_node_mask(part, ntype_id)
-            inner_nids = F.boolean_mask(part.ndata[dgl.NID], inner_node_mask)
-            min_nids = F.min(inner_nids, 0)
-            orig_id = F.boolean_mask(part.ndata['orig_id'], inner_node_mask)
             true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)
-            ndata = F.gather_row(node_feats[ntype + '/' + name], inner_nids - min_nids)
+            ndata = F.gather_row(node_feats[ntype + '/' + name], local_nids)
             assert np.all(F.asnumpy(ndata == true_feats))
+    for etype in g.etypes:
+        etype_id = g.get_etype_id(etype)
+        inner_edge_mask = _get_inner_edge_mask(part, etype_id)
+        inner_eids = F.boolean_mask(part.edata[dgl.EID], inner_edge_mask)
+        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)
+        partid = gpb.eid2partid(inner_type_eids, etype)
+        assert np.all(F.asnumpy(etype_ids) == etype_id)
+        assert np.all(F.asnumpy(partid) == gpb.partid)
+        orig_id = F.boolean_mask(part.edata['orig_id'], inner_edge_mask)
+        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)
+        for name in g.edges[etype].data:
+            if name in [dgl.EID, 'inner_edge']:
+                continue
+            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)
+            edata = F.gather_row(edge_feats[etype + '/' + name], local_eids)
+            assert np.all(F.asnumpy(edata == true_feats))
 
 def check_hetero_partition(hg, part_method):
     hg.nodes['n1'].data['labels'] = F.arange(0, hg.number_of_nodes('n1'))
     hg.nodes['n1'].data['feats'] = F.tensor(np.random.randn(hg.number_of_nodes('n1'), 10), F.float32)
     hg.edges['r1'].data['feats'] = F.tensor(np.random.randn(hg.number_of_edges('r1'), 10), F.float32)
+    hg.edges['r1'].data['labels'] = F.arange(0, hg.number_of_edges('r1'))
 
     num_parts = 4
     num_hops = 1
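The rewritten `verify_graph_feats` exercises a three-step ID mapping: a shuffled homogenized ID maps to a (type, per-type ID) pair via `map_to_per_ntype`, and the per-type ID maps to a partition-local row via `nid2localnid`, replacing the old `inner_nids - min_nids` shortcut. A toy model of that chain (all numbers and names are illustrative, not DGL's API):

```python
import numpy as np

# After reshuffling, each partition owns a contiguous range of shuffled IDs,
# so a shuffled ID maps to a local row by subtracting the partition's offset.
num_nodes, num_parts = 10, 2
perm = np.random.permutation(num_nodes)         # orig ID -> shuffled ID
inv = np.argsort(perm)                          # shuffled ID -> orig ID
feats = np.arange(num_nodes, dtype=np.float32)  # features keyed by orig ID

offsets = [0, 5, 10]  # partition p owns shuffled IDs [offsets[p], offsets[p+1])
part_feats = [feats[inv[offsets[p]:offsets[p + 1]]] for p in range(num_parts)]

# nid2localnid analogue: local row = shuffled ID - partition offset, and the
# row holds the feature of the original node that received that shuffled ID.
sid, p = 7, 1
assert part_feats[p][sid - offsets[p]] == feats[inv[sid]]
```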
@@ -150,6 +176,8 @@ def check_hetero_partition(hg, part_method):
     for etype in hg.etypes:
         assert len(orig_eids[etype]) == hg.number_of_edges(etype)
     parts = []
+    shuffled_labels = []
+    shuffled_elabels = []
     for i in range(num_parts):
         part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition('/tmp/partition/test.json', i)
         # Verify the mapping between the reshuffled IDs and the original IDs.
@@ -181,9 +209,21 @@ def check_hetero_partition(hg, part_method):
         assert len(orig_eids1) == len(orig_eids2)
         assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
         parts.append(part_g)
-        verify_graph_feats(hg, part_g, node_feats)
+        verify_graph_feats(hg, gpb, part_g, node_feats, edge_feats)
+        shuffled_labels.append(node_feats['n1/labels'])
+        shuffled_elabels.append(edge_feats['r1/labels'])
     verify_hetero_graph(hg, parts)
 
+    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
+    shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0))
+    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
+    orig_elabels = np.zeros(shuffled_elabels.shape, dtype=shuffled_elabels.dtype)
+    orig_labels[F.asnumpy(orig_nids['n1'])] = shuffled_labels
+    orig_elabels[F.asnumpy(orig_eids['r1'])] = shuffled_elabels
+    assert np.all(orig_labels == F.asnumpy(hg.nodes['n1'].data['labels']))
+    assert np.all(orig_elabels == F.asnumpy(hg.edges['r1'].data['labels']))
+
 def check_partition(g, part_method, reshuffle):
     g.ndata['labels'] = F.arange(0, g.number_of_nodes())
     g.ndata['feats'] = F.tensor(np.random.randn(g.number_of_nodes(), 10), F.float32)
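The new tail of `check_hetero_partition` concatenates the per-partition labels (which sit in shuffled order) and scatters them through `orig_nids`/`orig_eids` to recover the original ordering. A self-contained sketch of that inverse scatter (toy data, plain numpy rather than DGL calls):

```python
import numpy as np

# orig_nids[i] is the original ID of the node at shuffled position i, so
# scattering the shuffled labels through it restores the original order.
orig_nids = np.array([4, 0, 3, 1, 2])    # shuffled position -> orig ID
labels = np.array([10, 11, 12, 13, 14])  # labels keyed by orig ID
shuffled = labels[orig_nids]             # what the partitions store

restored = np.zeros_like(shuffled)
restored[orig_nids] = shuffled           # scatter back to orig order
assert np.all(restored == labels)
```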
@@ -196,6 +236,8 @@ def check_partition(g, part_method, reshuffle):
     orig_nids, orig_eids = partition_graph(g, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
                                            part_method=part_method, reshuffle=reshuffle, return_mapping=True)
     part_sizes = []
+    shuffled_labels = []
+    shuffled_edata = []
     for i in range(num_parts):
         part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition('/tmp/partition/test.json', i)
@@ -251,6 +293,7 @@ def check_partition(g, part_method, reshuffle):
         else:
             part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], part_g.ndata[dgl.NID])
             part_g.edata['feats'] = F.gather_row(g.edata['feats'], part_g.edata[dgl.NID])
+
         part_g.update_all(fn.copy_src('feats', 'msg'), fn.sum('msg', 'h'))
         part_g.update_all(fn.copy_edge('feats', 'msg'), fn.sum('msg', 'eh'))
         assert F.allclose(F.gather_row(g.ndata['h'], local_nodes),
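For context, the `update_all` assertions in this hunk verify that one round of message passing on the halo-augmented partition reproduces the full-graph result on inner nodes, which holds as long as every in-neighbor of an inner node is present in the partition. A small sketch of the aggregation being compared (assumes the PyTorch backend; `fn.copy_src`/`fn.sum` are the builtins the test uses):

```python
import dgl
import dgl.function as fn
import torch

# Sum in-neighbor features into 'h', the same reduction the test runs on
# both the full graph and each partition before comparing inner nodes.
g = dgl.rand_graph(20, 100)
g.ndata['feats'] = torch.randn(20, 4)
g.update_all(fn.copy_src('feats', 'msg'), fn.sum('msg', 'h'))

# Manual check for node 0: 'h' equals the sum of features over its in-edges.
src, _ = g.in_edges(0)
assert torch.allclose(g.ndata['h'][0], g.ndata['feats'][src].sum(0))
```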
@@ -261,11 +304,31 @@ def check_partition(g, part_method, reshuffle):
         for name in ['labels', 'feats']:
             assert '_N/' + name in node_feats
             assert node_feats['_N/' + name].shape[0] == len(local_nodes)
-            assert np.all(F.asnumpy(g.ndata[name])[F.asnumpy(local_nodes)] == F.asnumpy(node_feats['_N/' + name]))
+            true_feats = F.gather_row(g.ndata[name], local_nodes)
+            ndata = F.gather_row(node_feats['_N/' + name], local_nid)
+            assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata))
         for name in ['feats']:
             assert '_E/' + name in edge_feats
             assert edge_feats['_E/' + name].shape[0] == len(local_edges)
-            assert np.all(F.asnumpy(g.edata[name])[F.asnumpy(local_edges)] == F.asnumpy(edge_feats['_E/' + name]))
+            true_feats = F.gather_row(g.edata[name], local_edges)
+            edata = F.gather_row(edge_feats['_E/' + name], local_eid)
+            assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata))
+        # This only works if node/edge IDs are shuffled.
+        if reshuffle:
+            shuffled_labels.append(node_feats['_N/labels'])
+            shuffled_edata.append(edge_feats['_E/feats'])
+
+    # Verify that we can reconstruct node/edge data for original IDs.
+    if reshuffle:
+        shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))
+        shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))
+        orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)
+        orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)
+        orig_labels[F.asnumpy(orig_nids)] = shuffled_labels
+        orig_edata[F.asnumpy(orig_eids)] = shuffled_edata
+        assert np.all(orig_labels == F.asnumpy(g.ndata['labels']))
+        assert np.all(orig_edata == F.asnumpy(g.edata['feats']))
 
     if reshuffle:
         node_map = []
@@ -284,7 +347,7 @@ def check_partition(g, part_method, reshuffle):
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_partition():
-    g = create_random_graph(10000)
+    g = create_random_graph(1000)
     check_partition(g, 'metis', False)
     check_partition(g, 'metis', True)
     check_partition(g, 'random', False)
...
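Taken together, the tests follow a partition/load round trip. A minimal sketch of that flow under the same API the diff calls (the random graph, sizes, and path are arbitrary; assumes the PyTorch backend with METIS available, and `partition_graph`/`load_partition` from `dgl.distributed` with the signatures visible in the test):

```python
import dgl
import numpy as np
import torch
from dgl.distributed import partition_graph, load_partition

# Toy graph; the assertion mirrors the label round-trip check added here.
g = dgl.rand_graph(1000, 10000)
g.ndata['labels'] = torch.arange(g.number_of_nodes())

# reshuffle=True relabels nodes/edges so each partition owns a contiguous
# ID range; return_mapping=True yields shuffled-order -> original-ID arrays.
orig_nids, orig_eids = partition_graph(
    g, 'test', 4, '/tmp/partition', num_hops=1,
    part_method='metis', reshuffle=True, return_mapping=True)

shuffled = []
for i in range(4):
    part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = \
        load_partition('/tmp/partition/test.json', i)
    shuffled.append(node_feats['_N/labels'])

# Scatter back through the mapping to recover labels in original ID order.
labels = np.zeros(g.number_of_nodes(), dtype=np.int64)
labels[orig_nids.numpy()] = torch.cat(shuffled).numpy()
assert np.all(labels == g.ndata['labels'].numpy())
```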