Unverified Commit 366fb02f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable to load GB partitions (#7029)

parent 523bbb4c
...@@ -143,7 +143,58 @@ def _get_part_ranges(id_ranges): ...@@ -143,7 +143,58 @@ def _get_part_ranges(id_ranges):
return res return res
def load_partition(part_config, part_id, load_feats=True): def _verify_dgl_partition(graph, part_id, gpb, ntypes, etypes):
"""Verify the partition of a DGL graph."""
assert (
NID in graph.ndata
), "the partition graph should contain node mapping to global node ID"
assert (
EID in graph.edata
), "the partition graph should contain edge mapping to global edge ID"
for ntype in ntypes:
ntype_id = ntypes[ntype]
# graph.ndata[NID] are global homogeneous node IDs.
nids = F.boolean_mask(
graph.ndata[NID], _get_inner_node_mask(graph, ntype_id)
)
partids1 = gpb.nid2partid(nids)
_, per_type_nids = gpb.map_to_per_ntype(nids)
partids2 = gpb.nid2partid(per_type_nids, ntype)
assert np.all(F.asnumpy(partids1 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via global homogeneous node IDs."
)
assert np.all(F.asnumpy(partids2 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via type-wise node IDs."
)
for etype in etypes:
etype_id = etypes[etype]
# graph.edata[EID] are global homogeneous edge IDs.
eids = F.boolean_mask(
graph.edata[EID], _get_inner_edge_mask(graph, etype_id)
)
partids1 = gpb.eid2partid(eids)
_, per_type_eids = gpb.map_to_per_etype(eids)
partids2 = gpb.eid2partid(per_type_eids, etype)
assert np.all(F.asnumpy(partids1 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via global homogeneous edge IDs."
)
assert np.all(F.asnumpy(partids2 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via type-wise edge IDs."
)
def _verify_graphbolt_partition(graph, part_id, gpb, ntypes, etypes):
"""Verify the partition of a GraphBolt graph."""
# [Rui][TODO]
_, _, _, _, _ = graph, part_id, gpb, ntypes, etypes
def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):
"""Load data of a partition from the data path. """Load data of a partition from the data path.
A partition data includes a graph structure of the partition, a dict of node tensors, A partition data includes a graph structure of the partition, a dict of node tensors,
...@@ -165,6 +216,8 @@ def load_partition(part_config, part_id, load_feats=True): ...@@ -165,6 +216,8 @@ def load_partition(part_config, part_id, load_feats=True):
load_feats : bool, optional load_feats : bool, optional
Whether to load node/edge feats. If False, the returned node/edge feature Whether to load node/edge feats. If False, the returned node/edge feature
dictionaries will be empty. Default: True. dictionaries will be empty. Default: True.
use_graphbolt : bool, optional
Whether to load GraphBolt partition. Default: False.
Returns Returns
------- -------
...@@ -192,10 +245,13 @@ def load_partition(part_config, part_id, load_feats=True): ...@@ -192,10 +245,13 @@ def load_partition(part_config, part_id, load_feats=True):
"part-{}".format(part_id) in part_metadata "part-{}".format(part_id) in part_metadata
), "part-{} does not exist".format(part_id) ), "part-{} does not exist".format(part_id)
part_files = part_metadata["part-{}".format(part_id)] part_files = part_metadata["part-{}".format(part_id)]
part_graph_field = "part_graph"
if use_graphbolt:
part_graph_field = "part_graph_graphbolt"
assert ( assert (
"part_graph" in part_files part_graph_field in part_files
), "the partition does not contain graph structure." ), f"the partition does not contain graph structure: {part_graph_field}"
partition_path = relative_to_config(part_files["part_graph"]) partition_path = relative_to_config(part_files[part_graph_field])
logging.info( logging.info(
"Start to load partition from %s which is " "Start to load partition from %s which is "
"%d bytes. It may take non-trivial " "%d bytes. It may take non-trivial "
...@@ -203,54 +259,24 @@ def load_partition(part_config, part_id, load_feats=True): ...@@ -203,54 +259,24 @@ def load_partition(part_config, part_id, load_feats=True):
partition_path, partition_path,
os.path.getsize(partition_path), os.path.getsize(partition_path),
) )
graph = load_graphs(partition_path)[0][0] graph = (
logging.info("Finished loading partition.") torch.load(partition_path)
if use_graphbolt
assert ( else load_graphs(partition_path)[0][0]
NID in graph.ndata )
), "the partition graph should contain node mapping to global node ID" logging.info("Finished loading partition from %s.", partition_path)
assert (
EID in graph.edata
), "the partition graph should contain edge mapping to global edge ID"
gpb, graph_name, ntypes, etypes = load_partition_book(part_config, part_id) gpb, graph_name, ntypes, etypes = load_partition_book(part_config, part_id)
ntypes_list = list(ntypes.keys()) ntypes_list = list(ntypes.keys())
etypes_list = list(etypes.keys()) etypes_list = list(etypes.keys())
if "DGL_DIST_DEBUG" in os.environ: if "DGL_DIST_DEBUG" in os.environ:
for ntype in ntypes: _verify_func = (
ntype_id = ntypes[ntype] _verify_graphbolt_partition
# graph.ndata[NID] are global homogeneous node IDs. if use_graphbolt
nids = F.boolean_mask( else _verify_dgl_partition
graph.ndata[NID], _get_inner_node_mask(graph, ntype_id) )
) _verify_func(graph, part_id, gpb, ntypes, etypes)
partids1 = gpb.nid2partid(nids)
_, per_type_nids = gpb.map_to_per_ntype(nids)
partids2 = gpb.nid2partid(per_type_nids, ntype)
assert np.all(F.asnumpy(partids1 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via global homogeneous node IDs."
)
assert np.all(F.asnumpy(partids2 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via type-wise node IDs."
)
for etype in etypes:
etype_id = etypes[etype]
# graph.edata[EID] are global homogeneous edge IDs.
eids = F.boolean_mask(
graph.edata[EID], _get_inner_edge_mask(graph, etype_id)
)
partids1 = gpb.eid2partid(eids)
_, per_type_eids = gpb.map_to_per_etype(eids)
partids2 = gpb.eid2partid(per_type_eids, etype)
assert np.all(F.asnumpy(partids1 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via global homogeneous edge IDs."
)
assert np.all(F.asnumpy(partids2 == part_id)), (
"Unexpected partition IDs are found in the loaded partition "
"while querying via type-wise edge IDs."
)
node_feats = {} node_feats = {}
edge_feats = {} edge_feats = {}
...@@ -1357,9 +1383,9 @@ def dgl_partition_to_graphbolt( ...@@ -1357,9 +1383,9 @@ def dgl_partition_to_graphbolt(
torch.save(csc_graph, csc_graph_path) torch.save(csc_graph, csc_graph_path)
# Update graph path. # Update graph path.
new_part_meta[f"part-{part_id}"]["gb_part_graph"] = os.path.relpath( new_part_meta[f"part-{part_id}"][
csc_graph_path, os.path.dirname(part_config) "part_graph_graphbolt"
) ] = os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Update partition config. # Update partition config.
_dump_part_config(part_config, new_part_meta) _dump_part_config(part_config, new_part_meta)
......
...@@ -712,11 +712,9 @@ def test_dgl_partition_to_graphbolt_homo( ...@@ -712,11 +712,9 @@ def test_dgl_partition_to_graphbolt_homo(
orig_g = dgl.load_graphs( orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl") os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0] )[0][0]
new_g = th.load( new_g = load_partition(
os.path.join( part_config, part_id, load_feats=False, use_graphbolt=True
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt" )[0]
)
)
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
...@@ -784,11 +782,9 @@ def test_dgl_partition_to_graphbolt_hetero( ...@@ -784,11 +782,9 @@ def test_dgl_partition_to_graphbolt_hetero(
orig_g = dgl.load_graphs( orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl") os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0] )[0][0]
new_g = th.load( new_g = load_partition(
os.path.join( part_config, part_id, load_feats=False, use_graphbolt=True
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt" )[0]
)
)
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc() orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment