"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8d36d5adb1edb8eaaa40a29ef5510f51c503f19e"
Unverified Commit 1e886cb9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistDGL] sort node/edge_map to obtain expected id ranges (#5872)

parent 921476c8
......@@ -74,7 +74,7 @@ def _dump_part_config(part_config, part_metadata):
"""Format and dump part config."""
part_metadata = _format_part_metadata(part_metadata, _etype_tuple_to_str)
with open(part_config, "w") as outfile:
json.dump(part_metadata, outfile, sort_keys=True, indent=4)
json.dump(part_metadata, outfile, sort_keys=False, indent=4)
def _save_graphs(filename, g_list, formats=None, sort_etypes=False):
......@@ -420,6 +420,24 @@ def load_partition_book(part_config, part_id):
node_map = _get_part_ranges(node_map)
edge_map = _get_part_ranges(edge_map)
# Sort the node/edge maps by the node/edge type ID.
node_map = dict(sorted(node_map.items(), key=lambda x: ntypes[x[0]]))
edge_map = dict(sorted(edge_map.items(), key=lambda x: etypes[x[0]]))
def _assert_is_sorted(id_map):
    """Assert that the ID ranges stored in ``id_map`` are non-decreasing.

    The values of ``id_map`` are stacked into one array in the map's
    iteration order. The ranges are then collected partition by
    partition, for the first ``num_parts`` partitions (``num_parts`` is
    a free variable supplied by the surrounding scope), and flattened
    into a single sequence of IDs.

    Raises
    ------
    AssertionError
        If any ID in the flattened sequence is smaller than the one
        before it.
    """
    ranges_by_type = np.array(list(id_map.values()))
    # Regroup so that all ranges of partition 0 come first, then all
    # ranges of partition 1, and so on.
    per_part = [ranges_by_type[:, part, :] for part in range(num_parts)]
    flat_ids = np.array(per_part).flatten()
    is_sorted = np.all(flat_ids[:-1] <= flat_ids[1:])
    assert is_sorted, f"The node/edge map is not sorted: {flat_ids}"
_assert_is_sorted(node_map)
_assert_is_sorted(edge_map)
return (
RangePartitionBook(
part_id, num_parts, node_map, edge_map, ntypes, etypes
......
......@@ -737,3 +737,114 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
assert g.get_etype_id(edge_type) == type_id
assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
def test_not_sorted_node_edge_map():
    """Check that ``load_partition_book`` copes with unsorted maps.

    The partition config below lists the entries of ``node_map`` and
    ``edge_map`` in an order that does not follow the type IDs given in
    ``ntypes`` and ``etypes``. The expected local type offsets of both
    partitions correspond to the ranges arranged by type ID.
    """
    # Partition config whose node/edge maps are not ordered by type ID.
    part_config_str = """
{
"edge_map": {
"item:likes-rev:user": [
[
0,
100
],
[
1000,
1500
]
],
"user:follows-rev:user": [
[
300,
600
],
[
2100,
2800
]
],
"user:follows:user": [
[
100,
300
],
[
1500,
2100
]
],
"user:likes:item": [
[
600,
1000
],
[
2800,
3600
]
]
},
"etypes": {
"item:likes-rev:user": 0,
"user:follows-rev:user": 2,
"user:follows:user": 1,
"user:likes:item": 3
},
"graph_name": "test_graph",
"halo_hops": 1,
"node_map": {
"user": [
[
100,
300
],
[
600,
1000
]
],
"item": [
[
0,
100
],
[
300,
600
]
]
},
"ntypes": {
"user": 1,
"item": 0
},
"num_edges": 3600,
"num_nodes": 1000,
"num_parts": 2,
"part-0": {
"edge_feats": "part0/edge_feat.dgl",
"node_feats": "part0/node_feat.dgl",
"part_graph": "part0/graph.dgl"
},
"part-1": {
"edge_feats": "part1/edge_feat.dgl",
"node_feats": "part1/node_feat.dgl",
"part_graph": "part1/graph.dgl"
},
"part_method": "metis"
}
"""
    with tempfile.TemporaryDirectory() as test_dir:
        part_config = os.path.join(test_dir, "test_graph.json")
        # Write the config with an explicit encoding so the test does not
        # depend on the platform's default locale.
        with open(part_config, "w", encoding="utf-8") as file:
            file.write(part_config_str)
        # Part 0.
        gpb, _, _, _ = load_partition_book(part_config, 0)
        assert gpb.local_ntype_offset == [0, 100, 300]
        assert gpb.local_etype_offset == [0, 100, 300, 600, 1000]
        # Part 1.
        gpb, _, _, _ = load_partition_book(part_config, 1)
        assert gpb.local_ntype_offset == [0, 300, 700]
        assert gpb.local_etype_offset == [0, 500, 1100, 1800, 2600]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment