Unverified Commit 567a9df2 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add csc sampling graph python side code (#5700)

parent c83350da
/**
* Copyright (c) 2023 by Contributors
* @file python_binding.cc
* @brief Graph bolt library Python binding.
*/
#include <graphbolt/csc_sampling_graph.h>
namespace graphbolt {
namespace sampling {
TORCH_LIBRARY(graphbolt, m) {
m.class_<CSCSamplingGraph>("CSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes)
.def("num_edges", &CSCSamplingGraph::NumEdges)
.def("csc_indptr", &CSCSamplingGraph::CSCIndptr)
.def("indices", &CSCSamplingGraph::Indices)
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge);
m.def("from_csc", &CSCSamplingGraph::FromCSC);
}
} // namespace sampling
} // namespace graphbolt
"""GraphBolt""" """Graphbolt."""
import os
import sys
import torch
from .._ffi import libinfo
from .graph_storage import *
from .itemset import * from .itemset import *
def load_graphbolt():
"""Load Graphbolt C++ library"""
version = torch.__version__.split("+", maxsplit=1)[0]
if sys.platform.startswith("linux"):
basename = f"libgraphbolt_pytorch_{version}.so"
elif sys.platform.startswith("darwin"):
basename = f"libgraphbolt_pytorch_{version}.dylib"
elif sys.platform.startswith("win"):
basename = f"graphbolt_pytorch_{version}.dll"
else:
raise NotImplementedError("Unsupported system: %s" % sys.platform)
dirname = os.path.dirname(libinfo.find_lib_path()[0])
path = os.path.join(dirname, "graphbolt", basename)
if not os.path.exists(path):
raise FileNotFoundError(
f"Cannot find DGL C++ graphbolt library at {path}"
)
try:
torch.classes.load_library(path)
except Exception: # pylint: disable=W0703
raise ImportError("Cannot load Graphbolt C++ library")
load_graphbolt()
"""Graphbolt graph module."""
from .csc_sampling_graph import *
"""CSC format sampling graph."""
# pylint: disable= invalid-name
from typing import Dict, Optional, Tuple
import torch
class GraphMetadata:
r"""Class for metadata of csc sampling graph."""
def __init__(
self,
node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[Tuple[str, str, str], int],
):
"""Initialize the GraphMetadata object.
Parameters
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[Tuple[str, str, str], int]
Dictionary from edge types to edge type IDs.
Raises
------
AssertionError
If any of the assertions fail.
"""
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
edges = set()
for edge_type in edge_types:
src, edge, dst = edge_type
assert isinstance(edge, str), "Edge type name should be string."
assert edge not in edges, f"Edge type {edge} is defined repeatedly."
edges.add(edge)
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
self.node_type_to_id = node_type_to_id
self.edge_type_to_id = edge_type_to_id
class CSCSamplingGraph:
r"""Class for CSC sampling graph."""
def __repr__(self):
return _csc_sampling_graph_str(self)
def __init__(
self, c_csc_graph: torch.ScriptObject, metadata: Optional[GraphMetadata]
):
self._c_csc_graph = c_csc_graph
self._metadata = metadata
@property
def num_nodes(self) -> int:
"""Returns the number of nodes in the graph.
Returns
-------
int
The number of rows in the dense format.
"""
return self._c_csc_graph.num_nodes()
@property
def num_edges(self) -> int:
"""Returns the number of edges in the graph.
Returns
-------
int
The number of edges in the graph.
"""
return self._c_csc_graph.num_edges()
@property
def csc_indptr(self) -> torch.tensor:
"""Returns the indices pointer in the CSC graph.
Returns
-------
torch.tensor
The indices pointer in the CSC graph. An integer tensor with
shape `(num_nodes+1,)`.
"""
return self._c_csc_graph.csc_indptr()
@property
def indices(self) -> torch.tensor:
"""Returns the indices in the CSC graph.
Returns
-------
torch.tensor
The indices in the CSC graph. An integer tensor with shape
`(num_edges,)`.
Notes
-------
It is assumed that edges of each node are already sorted by edge type
ids.
"""
return self._c_csc_graph.indices()
@property
def node_type_offset(self) -> Optional[torch.Tensor]:
"""Returns the node type offset tensor if present.
Returns
-------
torch.Tensor or None
If present, returns a 1D integer tensor of shape
`(num_node_types + 1,)`. The tensor is in ascending order as nodes
of the same type have continuous IDs, and larger node IDs are
paired with larger node type IDs. The first value is 0 and last
value is the number of nodes. And nodes with IDs between
`node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.
"""
return self._c_csc_graph.node_type_offset()
@property
def type_per_edge(self) -> Optional[torch.Tensor]:
"""Returns the edge type tensor if present.
Returns
-------
torch.Tensor or None
If present, returns a 1D integer tensor of shape (num_edges,)
containing the type of each edge in the graph.
"""
return self._c_csc_graph.type_per_edge()
@property
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
return self._metadata
def from_csc(
csc_indptr: torch.Tensor,
indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None,
metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph:
"""Create a CSCSamplingGraph object from a CSC representation.
Parameters
----------
csc_indptr : torch.Tensor
Pointer to the start of each row in the `indices`. An integer tensor
with shape `(num_nodes+1,)`.
indices : torch.Tensor
Column indices of the non-zero elements in the CSC graph. An integer
tensor with shape `(num_edges,)`.
node_type_offset : Optional[torch.tensor], optional
Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns
-------
CSCSamplingGraph
The created CSCSamplingGraph object.
Examples
--------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {('n1', 'e1', 'n2'): 0, ('n1', 'e2', 'n3'): 1}
>>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
>>> csc_indptr = torch.tensor([0, 2, 5, 7])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_csc(csc_indptr, indices, node_type_offset, \
>>> type_per_edge, metadata)
>>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
num_nodes=3, num_edges=7)
"""
if metadata and metadata.node_type_to_id and node_type_offset is not None:
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge
),
metadata,
)
def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
"""Internal function for converting a csc sampling graph to string
representation.
"""
csc_indptr_str = str(graph.csc_indptr)
indices_str = str(graph.indices)
meta_str = f"num_nodes={graph.num_nodes}, num_edges={graph.num_edges}"
prefix = f"{type(graph).__name__}("
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
final_str = (
"csc_indptr="
+ _add_indent(csc_indptr_str, len("csc_indptr="))
+ ",\n"
+ "indices="
+ _add_indent(indices_str, len("indices="))
+ ",\n"
+ meta_str
+ ")"
)
final_str = prefix + _add_indent(final_str, len(prefix))
return final_str
import unittest
import backend as F
import dgl.graphbolt as gb
import pytest
import torch
torch.manual_seed(3407)
def get_metadata(num_ntypes, num_etypes):
ntypes = {f"n{i}": i for i in range(num_ntypes)}
etypes = {}
count = 0
for n1 in range(num_ntypes):
for n2 in range(n1, num_ntypes):
if count >= num_etypes:
break
etypes.update({(f"n{n1}", f"e{count}", f"n{n2}"): count})
count += 1
return gb.GraphMetadata(ntypes, etypes)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("num_nodes", [0, 1, 10, 100, 1000])
def test_empty_graph(num_nodes):
csc_indptr = torch.zeros((num_nodes + 1,), dtype=int)
indices = torch.tensor([])
graph = gb.from_csc(csc_indptr, indices)
assert graph.num_edges == 0
assert graph.num_nodes == num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr)
assert torch.equal(graph.indices, indices)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("num_nodes", [0, 1, 10, 100, 1000])
def test_hetero_empty_graph(num_nodes):
csc_indptr = torch.zeros((num_nodes + 1,), dtype=int)
indices = torch.tensor([])
metadata = get_metadata(num_ntypes=3, num_etypes=5)
# Some node types have no nodes.
if num_nodes == 0:
node_type_offset = torch.zeros((4,), dtype=int)
else:
node_type_offset = torch.sort(torch.randint(0, num_nodes, (4,)))[0]
node_type_offset[0] = 0
node_type_offset[-1] = num_nodes
type_per_edge = torch.tensor([])
graph = gb.from_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
metadata,
)
assert graph.num_edges == 0
assert graph.num_nodes == num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr)
assert torch.equal(graph.indices, indices)
assert graph.metadata.node_type_to_id == metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == metadata.edge_type_to_id
assert torch.equal(graph.node_type_offset, node_type_offset)
assert torch.equal(graph.type_per_edge, type_per_edge)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"ntypes", [{"n1": 1, "n2": 1}, {5: 1, "n2": 2}, {"n1": 1.5, "n2": 2.0}]
)
def test_metadata_with_ntype_exception(ntypes):
with pytest.raises(Exception):
gb.GraphMetadata(ntypes, {("n1", "e1", "n2"): 1})
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"etypes",
[
{("n1", 5, "n12"): 1},
{"e1": 1},
{("n1", "e1"): 1},
{("n1", "e1", 10): 1},
{("n1", "e1", "n2"): 1, ("n1", "e2", "n3"): 1},
{("n1", "e1", "n2"): 1, ("n1", "e1", "n3"): 2},
{("n1", "e1", "n10"): 1},
{("n1", "e1", "n2"): 1.5},
],
)
def test_metadata_with_etype_exception(etypes):
with pytest.raises(Exception):
gb.GraphMetadata({"n1": 0, "n2": 1, "n3": 2}, etypes)
def random_homo_graph(num_nodes, num_edges):
csc_indptr = torch.randint(0, num_edges, (num_nodes + 1,))
csc_indptr = torch.sort(csc_indptr)[0]
csc_indptr[0] = 0
csc_indptr[-1] = num_edges
indices = torch.randint(0, num_nodes, (num_edges,))
return csc_indptr, indices
def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
csc_indptr, indices = random_homo_graph(num_nodes, num_edges)
metadata = get_metadata(num_ntypes, num_etypes)
# Randomly get node type split point.
node_type_offset = torch.sort(
torch.randint(0, num_nodes, (num_ntypes + 1,))
)[0]
node_type_offset[0] = 0
node_type_offset[-1] = num_nodes
type_per_edge = []
for i in range(num_nodes):
num = csc_indptr[i + 1] - csc_indptr[i]
type_per_edge.append(
torch.sort(torch.randint(0, num_etypes, (num,)))[0]
)
type_per_edge = torch.cat(type_per_edge, dim=0)
return (csc_indptr, indices, node_type_offset, type_per_edge, metadata)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
def test_homo_graph(num_nodes, num_edges):
csc_indptr, indices = random_homo_graph(num_nodes, num_edges)
graph = gb.from_csc(csc_indptr, indices)
assert graph.num_nodes == num_nodes
assert graph.num_edges == num_edges
assert torch.equal(csc_indptr, graph.csc_indptr)
assert torch.equal(indices, graph.indices)
assert graph.metadata is None
assert graph.node_type_offset is None
assert graph.type_per_edge is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
def test_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
metadata,
) = random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
)
assert graph.num_nodes == num_nodes
assert graph.num_edges == num_edges
assert torch.equal(csc_indptr, graph.csc_indptr)
assert torch.equal(indices, graph.indices)
assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge)
assert metadata.node_type_to_id == graph.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"node_type_offset",
[
torch.tensor([0, 1]),
torch.tensor([0, 1, 5, 6, 10]),
torch.tensor([0, 1, 10]),
],
)
def test_node_type_offset_wrong_legnth(node_type_offset):
num_ntypes = 3
csc_indptr, indices, _, type_per_edge, metadata = random_hetero_graph(
10, 50, num_ntypes, 5
)
with pytest.raises(Exception):
gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, metadata
)
if __name__ == "__main__":
test_empty_graph(10)
test_node_type_offset_wrong_legnth(torch.tensor([0, 1, 5]))
test_hetero_graph(10, 50, 3, 5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment