Unverified Commit 236ffa0f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] Add in_subgraph() API at the Python level (#5743)

parent 2cb7c69d
...@@ -108,7 +108,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -108,7 +108,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype()); torch::zeros({nonzero_idx.size(0) + 1}, indptr_.dtype());
compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx})); compact_indptr.index_put_({Slice(1, None)}, indptr.index({nonzero_idx}));
return c10::make_intrusive<SampledSubgraph>( return c10::make_intrusive<SampledSubgraph>(
compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx, compact_indptr.cumsum(0), torch::cat(indices_arr), nonzero_idx - 1,
torch::arange(0, NumNodes()), torch::cat(edge_ids_arr), torch::arange(0, NumNodes()), torch::cat(edge_ids_arr),
type_per_edge_ type_per_edge_
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)} ? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
......
...@@ -27,7 +27,8 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -27,7 +27,8 @@ TORCH_LIBRARY(graphbolt, m) {
.def("csc_indptr", &CSCSamplingGraph::CSCIndptr) .def("csc_indptr", &CSCSamplingGraph::CSCIndptr)
.def("indices", &CSCSamplingGraph::Indices) .def("indices", &CSCSamplingGraph::Indices)
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge); .def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph);
m.def("from_csc", &CSCSamplingGraph::FromCSC); m.def("from_csc", &CSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph); m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph); m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
......
...@@ -173,6 +173,30 @@ class CSCSamplingGraph: ...@@ -173,6 +173,30 @@ class CSCSamplingGraph:
""" """
return self._metadata return self._metadata
def in_subgraph(self, nodes: torch.Tensor) -> torch.ScriptObject:
    """Return the subgraph induced on the inbound edges of the given nodes.

    An in subgraph is equivalent to creating a new graph using the incoming
    edges of the given nodes.

    Parameters
    ----------
    nodes : torch.Tensor
        The nodes to form the subgraph which are type agnostic. Must be a
        1-D tensor that contains no duplicate values.

    Returns
    -------
    SampledSubgraph
        The in subgraph, as returned by the underlying C++ graph object.

    Raises
    ------
    AssertionError
        If ``nodes`` is not a 1-D tensor or contains duplicate values.
    """
    # Raise explicitly rather than via ``assert`` statements so that the
    # validation is not stripped when Python runs with ``-O``. The exception
    # type and messages are kept identical for existing callers.
    if nodes.dim() != 1:
        raise AssertionError("Nodes should be 1-D tensor.")
    # For a 1-D tensor, ``numel()`` equals ``len()``; a smaller unique count
    # means at least one node id appears more than once.
    if torch.unique(nodes).numel() != nodes.numel():
        raise AssertionError("Nodes cannot have duplicate values.")
    return self._c_csc_graph.in_subgraph(nodes)
def from_csc( def from_csc(
csc_indptr: torch.Tensor, csc_indptr: torch.Tensor,
......
...@@ -274,3 +274,113 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes): ...@@ -274,3 +274,113 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_in_subgraph_homogeneous():
    """Check ``in_subgraph`` on a small homogeneous graph.

    Dense adjacency matrix of the original graph (rows are source nodes,
    columns are destination nodes):
        1 0 1 0 1
        1 0 1 1 0
        0 1 0 1 0
        0 1 0 0 1
        1 0 0 0 1
    """
    # CSC description of the graph above.
    total_nodes, total_edges = 5, 12
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    csc_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert csc_indptr[-1] == total_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the sampling graph and take the in subgraph of the seed nodes.
    sampling_graph = gb.from_csc(csc_indptr, csc_indices)
    seeds = torch.LongTensor([1, 3, 4])
    result = sampling_graph.in_subgraph(seeds)

    # Expected tensors, checked in the same order as they are listed.
    expectations = (
        ("indptr", torch.LongTensor([0, 2, 4, 7])),
        ("indices", torch.LongTensor([2, 3, 1, 2, 0, 3, 4])),
        ("reverse_column_node_ids", seeds),
        ("reverse_row_node_ids", torch.arange(0, total_nodes)),
        ("reverse_edge_ids", torch.LongTensor([3, 4, 7, 8, 9, 10, 11])),
    )
    for attr_name, expected in expectations:
        assert torch.equal(getattr(result, attr_name), expected)
    # A homogeneous graph carries no per-edge type information.
    assert result.type_per_edge is None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_in_subgraph_heterogeneous():
    """Check ``in_subgraph`` on a small heterogeneous graph.

    Dense adjacency matrix of the original graph (rows are source nodes,
    columns are destination nodes):
        1 0 1 0 1
        1 0 1 1 0
        0 1 0 1 0
        0 1 0 0 1
        1 0 0 0 1

    node_type_0: [0, 1]
    node_type_1: [2, 3, 4]
    edge_type_0: node_type_0 -> node_type_0
    edge_type_1: node_type_0 -> node_type_1
    edge_type_2: node_type_1 -> node_type_0
    edge_type_3: node_type_1 -> node_type_1
    """
    # Type vocabularies for nodes and edges.
    total_nodes, total_edges = 5, 12
    node_type_ids = {"N0": 0, "N1": 1}
    edge_type_ids = {
        ("N0", "R0", "N0"): 0,
        ("N0", "R1", "N1"): 1,
        ("N1", "R2", "N0"): 2,
        ("N1", "R3", "N1"): 3,
    }

    # CSC description of the graph plus per-node / per-edge type data.
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    csc_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    ntype_offset = torch.LongTensor([0, 2, 5])
    edge_types = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
    assert csc_indptr[-1] == total_edges
    assert csc_indptr[-1] == len(csc_indices)
    assert ntype_offset[-1] == total_nodes
    assert all(edge_types < len(edge_type_ids))

    # Build the sampling graph and take the in subgraph of the seed nodes.
    graph_metadata = gb.GraphMetadata(node_type_ids, edge_type_ids)
    sampling_graph = gb.from_csc(
        csc_indptr, csc_indices, ntype_offset, edge_types, graph_metadata
    )
    seeds = torch.LongTensor([1, 3, 4])
    result = sampling_graph.in_subgraph(seeds)

    # Expected tensors, checked in the same order as they are listed.
    expectations = (
        ("indptr", torch.LongTensor([0, 2, 4, 7])),
        ("indices", torch.LongTensor([2, 3, 1, 2, 0, 3, 4])),
        ("reverse_column_node_ids", seeds),
        ("reverse_row_node_ids", torch.arange(0, total_nodes)),
        ("reverse_edge_ids", torch.LongTensor([3, 4, 7, 8, 9, 10, 11])),
        ("type_per_edge", torch.LongTensor([2, 2, 1, 3, 1, 3, 3])),
    )
    for attr_name, expected in expectations:
        assert torch.equal(getattr(result, attr_name), expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment