Unverified Commit 571340da authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] update partition_graph to support graphbolt (#7053)

parent df6b3250
......@@ -638,6 +638,8 @@ def partition_graph(
num_trainers_per_machine=1,
objtype="cut",
graph_formats=None,
use_graphbolt=False,
**kwargs,
):
"""Partition a graph for distributed training and store the partitions on files.
......@@ -811,6 +813,10 @@ def partition_graph(
``csc`` and ``csr``. If not specified, save one format only according to what
format is available. If multiple formats are available, selection priority
from high to low is ``coo``, ``csc``, ``csr``.
use_graphbolt : bool, optional
Whether to save partitions in GraphBolt format. Default: False.
kwargs : dict
Other keyword arguments for converting DGL partitions to GraphBolt.
Returns
-------
......@@ -1298,7 +1304,8 @@ def partition_graph(
)
)
_dump_part_config(f"{out_path}/{graph_name}.json", part_metadata)
part_config = os.path.join(out_path, graph_name + ".json")
_dump_part_config(part_config, part_metadata)
num_cuts = sim_g.num_edges() - tot_num_inner_edges
if num_parts == 1:
......@@ -1309,6 +1316,12 @@ def partition_graph(
)
)
if use_graphbolt:
dgl_partition_to_graphbolt(
part_config,
**kwargs,
)
if return_mapping:
return orig_nids, orig_eids
......
......@@ -19,7 +19,6 @@ import pytest
import torch as th
from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import (
dgl_partition_to_graphbolt,
DistEmbedding,
DistGraph,
DistGraphServer,
......@@ -155,10 +154,9 @@ def check_server_client_empty(
# Partition the graph
num_parts = 1
graph_name = "dist_graph_test_1"
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
partition_graph(
g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
......@@ -586,10 +584,9 @@ def check_server_client(
graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
partition_graph(
g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
......@@ -662,10 +659,8 @@ def check_server_client_hierarchy(
num_parts,
"/tmp/dist_graph",
num_trainers_per_machine=num_clients,
use_graphbolt=use_graphbolt,
)
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
......@@ -928,10 +923,9 @@ def check_server_client_hetero(
# Partition the graph
num_parts = 1
graph_name = "dist_graph_test_3"
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
partition_graph(
g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
)
# let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine.
......
......@@ -944,3 +944,165 @@ def test_not_sorted_node_edge_map():
gpb, _, _, _ = load_partition_book(part_config, 1)
assert gpb.local_ntype_offset == [0, 300, 700]
assert gpb.local_etype_offset == [0, 500, 1100, 1800, 2600]
@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("store_eids", [True, False])
@pytest.mark.parametrize("store_inner_node", [True, False])
@pytest.mark.parametrize("store_inner_edge", [True, False])
@pytest.mark.parametrize("debug_mode", [True, False])
def test_partition_graph_graphbolt_homo(
part_method,
num_parts,
store_eids,
store_inner_node,
store_inner_edge,
debug_mode,
):
reset_envs()
if debug_mode:
os.environ["DGL_DIST_DEBUG"] = "1"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_graph(1000)
graph_name = "test"
partition_graph(
g,
graph_name,
num_parts,
test_dir,
part_method=part_method,
use_graphbolt=True,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
for part_id in range(num_parts):
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = load_partition(
part_config, part_id, load_feats=False, use_graphbolt=True
)[0]
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
if store_inner_node or debug_mode:
assert th.equal(
orig_g.ndata["inner_node"],
new_g.node_attributes["inner_node"],
)
else:
assert "inner_node" not in new_g.node_attributes
if store_eids or debug_mode:
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
)
else:
assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode:
assert th.equal(
orig_g.edata["inner_edge"][orig_eids],
new_g.edge_attributes["inner_edge"],
)
else:
assert "inner_edge" not in new_g.edge_attributes
assert new_g.type_per_edge is None
assert new_g.node_type_to_id is None
assert new_g.edge_type_to_id is None
@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("store_eids", [True, False])
@pytest.mark.parametrize("store_inner_node", [True, False])
@pytest.mark.parametrize("store_inner_edge", [True, False])
@pytest.mark.parametrize("debug_mode", [True, False])
def test_partition_graph_graphbolt_hetero(
part_method,
num_parts,
store_eids,
store_inner_node,
store_inner_edge,
debug_mode,
):
reset_envs()
if debug_mode:
os.environ["DGL_DIST_DEBUG"] = "1"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_hetero()
graph_name = "test"
partition_graph(
g,
graph_name,
num_parts,
test_dir,
part_method=part_method,
use_graphbolt=True,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
for part_id in range(num_parts):
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = load_partition(
part_config, part_id, load_feats=False, use_graphbolt=True
)[0]
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
if store_inner_node or debug_mode:
assert th.equal(
orig_g.ndata["inner_node"],
new_g.node_attributes["inner_node"],
)
else:
assert "inner_node" not in new_g.node_attributes
if debug_mode:
assert th.equal(
orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE]
)
else:
assert dgl.NTYPE not in new_g.node_attributes
if store_eids or debug_mode:
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
)
else:
assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode:
assert th.equal(
orig_g.edata["inner_edge"],
new_g.edge_attributes["inner_edge"],
)
else:
assert "inner_edge" not in new_g.edge_attributes
if debug_mode:
assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids],
new_g.edge_attributes[dgl.ETYPE],
)
else:
assert dgl.ETYPE not in new_g.edge_attributes
assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge
)
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment