Unverified Commit 571340da authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] update partition_graph to support graphbolt (#7053)

parent df6b3250
...@@ -638,6 +638,8 @@ def partition_graph( ...@@ -638,6 +638,8 @@ def partition_graph(
num_trainers_per_machine=1, num_trainers_per_machine=1,
objtype="cut", objtype="cut",
graph_formats=None, graph_formats=None,
use_graphbolt=False,
**kwargs,
): ):
"""Partition a graph for distributed training and store the partitions on files. """Partition a graph for distributed training and store the partitions on files.
...@@ -811,6 +813,10 @@ def partition_graph( ...@@ -811,6 +813,10 @@ def partition_graph(
``csc`` and ``csr``. If not specified, save one format only according to what ``csc`` and ``csr``. If not specified, save one format only according to what
format is available. If multiple formats are available, selection priority format is available. If multiple formats are available, selection priority
from high to low is ``coo``, ``csc``, ``csr``. from high to low is ``coo``, ``csc``, ``csr``.
use_graphbolt : bool, optional
Whether to save partitions in GraphBolt format. Default: False.
kwargs : dict
Other keyword arguments for converting DGL partitions to GraphBolt.
Returns Returns
------- -------
...@@ -1298,7 +1304,8 @@ def partition_graph( ...@@ -1298,7 +1304,8 @@ def partition_graph(
) )
) )
_dump_part_config(f"{out_path}/{graph_name}.json", part_metadata) part_config = os.path.join(out_path, graph_name + ".json")
_dump_part_config(part_config, part_metadata)
num_cuts = sim_g.num_edges() - tot_num_inner_edges num_cuts = sim_g.num_edges() - tot_num_inner_edges
if num_parts == 1: if num_parts == 1:
...@@ -1309,6 +1316,12 @@ def partition_graph( ...@@ -1309,6 +1316,12 @@ def partition_graph(
) )
) )
if use_graphbolt:
dgl_partition_to_graphbolt(
part_config,
**kwargs,
)
if return_mapping: if return_mapping:
return orig_nids, orig_eids return orig_nids, orig_eids
......
...@@ -19,7 +19,6 @@ import pytest ...@@ -19,7 +19,6 @@ import pytest
import torch as th import torch as th
from dgl.data.utils import load_graphs, save_graphs from dgl.data.utils import load_graphs, save_graphs
from dgl.distributed import ( from dgl.distributed import (
dgl_partition_to_graphbolt,
DistEmbedding, DistEmbedding,
DistGraph, DistGraph,
DistGraphServer, DistGraphServer,
...@@ -155,10 +154,9 @@ def check_server_client_empty( ...@@ -155,10 +154,9 @@ def check_server_client_empty(
# Partition the graph # Partition the graph
num_parts = 1 num_parts = 1
graph_name = "dist_graph_test_1" graph_name = "dist_graph_test_1"
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph") partition_graph(
if use_graphbolt: g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json") )
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -586,10 +584,9 @@ def check_server_client( ...@@ -586,10 +584,9 @@ def check_server_client(
graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}" graph_name = f"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}"
g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1) g.ndata["features"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1) g.edata["features"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph") partition_graph(
if use_graphbolt: g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json") )
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -662,10 +659,8 @@ def check_server_client_hierarchy( ...@@ -662,10 +659,8 @@ def check_server_client_hierarchy(
num_parts, num_parts,
"/tmp/dist_graph", "/tmp/dist_graph",
num_trainers_per_machine=num_clients, num_trainers_per_machine=num_clients,
use_graphbolt=use_graphbolt,
) )
if use_graphbolt:
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json")
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
...@@ -928,10 +923,9 @@ def check_server_client_hetero( ...@@ -928,10 +923,9 @@ def check_server_client_hetero(
# Partition the graph # Partition the graph
num_parts = 1 num_parts = 1
graph_name = "dist_graph_test_3" graph_name = "dist_graph_test_3"
partition_graph(g, graph_name, num_parts, "/tmp/dist_graph") partition_graph(
if use_graphbolt: g, graph_name, num_parts, "/tmp/dist_graph", use_graphbolt=use_graphbolt
part_config = os.path.join("/tmp/dist_graph", f"{graph_name}.json") )
dgl_partition_to_graphbolt(part_config)
# let's just test on one partition for now. # let's just test on one partition for now.
# We cannot run multiple servers and clients on the same machine. # We cannot run multiple servers and clients on the same machine.
......
...@@ -944,3 +944,165 @@ def test_not_sorted_node_edge_map(): ...@@ -944,3 +944,165 @@ def test_not_sorted_node_edge_map():
gpb, _, _, _ = load_partition_book(part_config, 1) gpb, _, _, _ = load_partition_book(part_config, 1)
assert gpb.local_ntype_offset == [0, 300, 700] assert gpb.local_ntype_offset == [0, 300, 700]
assert gpb.local_etype_offset == [0, 500, 1100, 1800, 2600] assert gpb.local_etype_offset == [0, 500, 1100, 1800, 2600]
@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("store_eids", [True, False])
@pytest.mark.parametrize("store_inner_node", [True, False])
@pytest.mark.parametrize("store_inner_edge", [True, False])
@pytest.mark.parametrize("debug_mode", [True, False])
def test_partition_graph_graphbolt_homo(
part_method,
num_parts,
store_eids,
store_inner_node,
store_inner_edge,
debug_mode,
):
reset_envs()
if debug_mode:
os.environ["DGL_DIST_DEBUG"] = "1"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_graph(1000)
graph_name = "test"
partition_graph(
g,
graph_name,
num_parts,
test_dir,
part_method=part_method,
use_graphbolt=True,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
for part_id in range(num_parts):
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = load_partition(
part_config, part_id, load_feats=False, use_graphbolt=True
)[0]
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
if store_inner_node or debug_mode:
assert th.equal(
orig_g.ndata["inner_node"],
new_g.node_attributes["inner_node"],
)
else:
assert "inner_node" not in new_g.node_attributes
if store_eids or debug_mode:
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
)
else:
assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode:
assert th.equal(
orig_g.edata["inner_edge"][orig_eids],
new_g.edge_attributes["inner_edge"],
)
else:
assert "inner_edge" not in new_g.edge_attributes
assert new_g.type_per_edge is None
assert new_g.node_type_to_id is None
assert new_g.edge_type_to_id is None
@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("store_eids", [True, False])
@pytest.mark.parametrize("store_inner_node", [True, False])
@pytest.mark.parametrize("store_inner_edge", [True, False])
@pytest.mark.parametrize("debug_mode", [True, False])
def test_partition_graph_graphbolt_hetero(
part_method,
num_parts,
store_eids,
store_inner_node,
store_inner_edge,
debug_mode,
):
reset_envs()
if debug_mode:
os.environ["DGL_DIST_DEBUG"] = "1"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_hetero()
graph_name = "test"
partition_graph(
g,
graph_name,
num_parts,
test_dir,
part_method=part_method,
use_graphbolt=True,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
for part_id in range(num_parts):
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = load_partition(
part_config, part_id, load_feats=False, use_graphbolt=True
)[0]
orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
assert th.equal(
orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]
)
if store_inner_node or debug_mode:
assert th.equal(
orig_g.ndata["inner_node"],
new_g.node_attributes["inner_node"],
)
else:
assert "inner_node" not in new_g.node_attributes
if debug_mode:
assert th.equal(
orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE]
)
else:
assert dgl.NTYPE not in new_g.node_attributes
if store_eids or debug_mode:
assert th.equal(
orig_g.edata[dgl.EID][orig_eids],
new_g.edge_attributes[dgl.EID],
)
else:
assert dgl.EID not in new_g.edge_attributes
if store_inner_edge or debug_mode:
assert th.equal(
orig_g.edata["inner_edge"],
new_g.edge_attributes["inner_edge"],
)
else:
assert "inner_edge" not in new_g.edge_attributes
if debug_mode:
assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids],
new_g.edge_attributes[dgl.ETYPE],
)
else:
assert dgl.ETYPE not in new_g.edge_attributes
assert th.equal(
orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge
)
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment