Unverified commit 00972dee, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] convert dgl partition to CSCSamplingGraph (#5736)

parent 236ffa0f
......@@ -19,7 +19,6 @@ from . import (
container,
cuda,
dataloading,
distributed,
function,
ops,
random,
......@@ -64,3 +63,6 @@ from .frame import LazyFeature
from .global_config import is_libxsmm_enabled, use_libxsmm
from .utils import apply_each
from .mpops import *
if backend_name == "pytorch":
from . import distributed
......@@ -9,6 +9,7 @@ from .graph_services import *
from .kvstore import KVClient, KVServer
from .nn import *
from .partition import (
convert_dgl_partition_to_csc_sampling_graph,
load_partition,
load_partition_book,
load_partition_feats,
......
......@@ -25,6 +25,9 @@ from .graph_partition_book import (
RangePartitionBook,
)
if F.backend_name == "pytorch":
from .. import graphbolt
RESERVED_FIELD_DTYPE = {
"inner_node": F.uint8, # A flag indicates whether the node is inside a partition.
"inner_edge": F.uint8, # A flag indicates whether the edge is inside a partition.
......@@ -1200,3 +1203,54 @@ def partition_graph(
if return_mapping:
return orig_nids, orig_eids
def convert_dgl_partition_to_csc_sampling_graph(part_config):
    """Convert every DGL partition into a GraphBolt ``CSCSamplingGraph``.

    For each partition listed in the partition configuration, the
    partitioned ``DGLGraph`` is loaded (without features), turned into a
    ``CSCSamplingGraph`` and saved as ``csc_sampling_graph.tar`` in the same
    directory as the partition's original graph file.

    This is a transitional API: once partitions are written as
    ``CSCSamplingGraph`` directly, it is expected to be deprecated.

    Parameters
    ----------
    part_config : str
        The partition configuration JSON file.
    """
    part_meta = _load_part_config(part_config)
    # Graph file locations in the config are joined onto the directory that
    # holds the config file itself.
    config_dir = os.path.dirname(part_config)

    for part_id in range(part_meta["num_parts"]):
        graph, _, _, gpb, _, _, _ = load_partition(
            part_config, part_id, load_feats=False
        )

        # Node/edge type information for the GraphBolt metadata comes from
        # the partition book loader.
        _, _, ntypes, etypes = load_partition_book(part_config, part_id)
        metadata = graphbolt.GraphMetadata(ntypes, etypes)

        # CSC indptr and indices of the partition's adjacency matrix.
        # NOTE(review): the third value returned by csc() is discarded here;
        # confirm that type_per_edge needs no reordering to line up with
        # `indices`.
        indptr, indices, _ = graph.adj().csc()

        # Per-edge type IDs, derived from the partition book and cast to the
        # dtype reserved for the ETYPE field.
        etype_ids = gpb.map_to_per_etype(graph.edata[EID])[0]
        type_per_edge = etype_ids.to(RESERVED_FIELD_DTYPE[ETYPE])
        # Sanity check: exactly one type ID per edge.
        assert graph.num_edges() == len(type_per_edge)

        csc_graph = graphbolt.from_csc(
            indptr, indices, None, type_per_edge, metadata
        )

        # Store the converted graph next to the original partition graph.
        orig_graph_path = os.path.join(
            config_dir, part_meta[f"part-{part_id}"]["part_graph"]
        )
        csc_graph_path = os.path.join(
            os.path.dirname(orig_graph_path), "csc_sampling_graph.tar"
        )
        graphbolt.save_csc_sampling_graph(csc_graph, csc_graph_path)
......@@ -47,12 +47,9 @@ class GraphMetadata:
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
edges = set()
for edge_type in edge_types:
src, edge, dst = edge_type
assert isinstance(edge, str), "Edge type name should be string."
assert edge not in edges, f"Edge type {edge} is defined repeatedly."
edges.add(edge)
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
......
......@@ -9,6 +9,7 @@ import pytest
import torch as th
from dgl import function as fn
from dgl.distributed import (
convert_dgl_partition_to_csc_sampling_graph,
load_partition,
load_partition_book,
load_partition_feats,
......@@ -604,13 +605,13 @@ def test_RangePartitionBook():
expect_except = False
try:
gpb.to_canonical_etype(("node1", "edge2", "node2"))
except:
except BaseException:
expect_except = True
assert expect_except
expect_except = False
try:
gpb.to_canonical_etype("edge2")
except:
except BaseException:
expect_except = True
assert expect_except
......@@ -645,7 +646,7 @@ def test_RangePartitionBook():
expect_except = False
try:
HeteroDataName(False, "edge1", "feat")
except:
except BaseException:
expect_except = True
assert expect_except
data_name = HeteroDataName(False, c_etype, "feat")
......@@ -674,3 +675,65 @@ def test_UnknownPartitionBook():
except Exception as e:
if not isinstance(e, TypeError):
raise e
@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
def test_convert_dgl_partition_to_csc_sampling_graph_homo(
    part_method, num_parts
):
    """Check the DGL -> CSCSamplingGraph conversion on a random graph.

    The graph from ``create_random_graph`` is partitioned and converted;
    each saved ``csc_sampling_graph.tar`` is then compared against the CSC
    form of the corresponding ``graph.dgl`` partition.
    """
    graph_name = "test"
    with tempfile.TemporaryDirectory() as test_dir:
        g = create_random_graph(1000)
        partition_graph(
            g, graph_name, num_parts, test_dir, part_method=part_method
        )
        convert_dgl_partition_to_csc_sampling_graph(
            os.path.join(test_dir, f"{graph_name}.json")
        )

        for part_id in range(num_parts):
            dgl_path = os.path.join(test_dir, f"part{part_id}/graph.dgl")
            csc_path = os.path.join(
                test_dir, f"part{part_id}/csc_sampling_graph.tar"
            )
            loaded = dgl.load_graphs(dgl_path)
            orig_g = loaded[0][0]
            new_g = dgl.graphbolt.load_csc_sampling_graph(csc_path)

            # Graph structure must match the original partition's CSC form.
            expected_indptr, expected_indices, _ = orig_g.adj().csc()
            assert th.equal(expected_indptr, new_g.csc_indptr)
            assert th.equal(expected_indices, new_g.indices)

            # This test expects no node type offsets and edge type 0 only.
            assert new_g.node_type_offset is None
            assert all(new_g.type_per_edge == 0)

            # Type name -> ID mappings must agree with the source graph.
            node_type_ids = new_g.metadata.node_type_to_id
            for ntype, ntype_id in node_type_ids.items():
                assert g.get_ntype_id(ntype) == ntype_id
            edge_type_ids = new_g.metadata.edge_type_to_id
            for etype, etype_id in edge_type_ids.items():
                assert g.get_etype_id(etype) == etype_id
@pytest.mark.parametrize("part_method", ["metis", "random"])
@pytest.mark.parametrize("num_parts", [1, 4])
def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
    part_method, num_parts
):
    """Check the DGL -> CSCSamplingGraph conversion on a random hetero graph.

    The graph from ``create_random_hetero`` is partitioned and converted;
    each saved ``csc_sampling_graph.tar`` is then compared against the CSC
    form and the ``dgl.ETYPE`` edge data of the corresponding ``graph.dgl``
    partition.
    """
    graph_name = "test"
    with tempfile.TemporaryDirectory() as test_dir:
        g = create_random_hetero()
        partition_graph(
            g, graph_name, num_parts, test_dir, part_method=part_method
        )
        convert_dgl_partition_to_csc_sampling_graph(
            os.path.join(test_dir, f"{graph_name}.json")
        )

        for part_id in range(num_parts):
            dgl_path = os.path.join(test_dir, f"part{part_id}/graph.dgl")
            csc_path = os.path.join(
                test_dir, f"part{part_id}/csc_sampling_graph.tar"
            )
            loaded = dgl.load_graphs(dgl_path)
            orig_g = loaded[0][0]
            new_g = dgl.graphbolt.load_csc_sampling_graph(csc_path)

            # Graph structure must match the original partition's CSC form.
            expected_indptr, expected_indices, _ = orig_g.adj().csc()
            assert th.equal(expected_indptr, new_g.csc_indptr)
            assert th.equal(expected_indices, new_g.indices)

            # Type name -> ID mappings must agree with the source graph.
            node_type_ids = new_g.metadata.node_type_to_id
            for ntype, ntype_id in node_type_ids.items():
                assert g.get_ntype_id(ntype) == ntype_id
            edge_type_ids = new_g.metadata.edge_type_to_id
            for etype, etype_id in edge_type_ids.items():
                assert g.get_etype_id(etype) == etype_id

            # No node type offsets are expected, and the per-edge types must
            # equal the partition's stored ETYPE edge data.
            assert new_g.node_type_offset is None
            assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......@@ -98,7 +98,6 @@ def test_metadata_with_ntype_exception(ntypes):
{("n1", "e1"): 1},
{("n1", "e1", 10): 1},
{("n1", "e1", "n2"): 1, ("n1", "e2", "n3"): 1},
{("n1", "e1", "n2"): 1, ("n1", "e1", "n3"): 2},
{("n1", "e1", "n10"): 1},
{("n1", "e1", "n2"): 1.5},
],
......
......@@ -7,7 +7,7 @@ if [ $# -ne 1 ]; then
exit -1
fi
CMAKE_VARS="-DBUILD_CPP_TEST=ON -DUSE_OPENMP=ON"
CMAKE_VARS="-DBUILD_CPP_TEST=ON -DUSE_OPENMP=ON -DBUILD_GRAPHBOLT=ON"
# This is a semicolon-separated list of Python interpreters containing PyTorch.
# The value here is for CI. Replace it with your own or comment this whole
# statement for default Python interpreter.
......@@ -15,7 +15,7 @@ if [ "$1" != "cugraph" ]; then
# We do not build pytorch for cugraph because currently building
# pytorch against all the supported cugraph versions is not supported
# See issue: https://github.com/rapidsai/cudf/issues/8510
CMAKE_VARS="$CMAKE_VARS -DBUILD_TORCH=ON -DBUILD_GRAPHBOLT=ON -DTORCH_PYTHON_INTERPS=/opt/conda/envs/pytorch-ci/bin/python"
CMAKE_VARS="$CMAKE_VARS -DBUILD_TORCH=ON -DTORCH_PYTHON_INTERPS=/opt/conda/envs/pytorch-ci/bin/python"
else
# Disable sparse build as cugraph docker image lacks cuDNN.
CMAKE_VARS="$CMAKE_VARS -DBUILD_TORCH=OFF -DBUILD_SPARSE=OFF"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment