Unverified Commit b35757a0 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Enable `FusedCSCSamplingGraph.in_subgraph()` to accept dict (#6550)

parent 2d689235
...@@ -283,7 +283,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -283,7 +283,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
""" """
return self._metadata return self._metadata
def in_subgraph(self, nodes: torch.Tensor) -> torch.ScriptObject: def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl:
"""Return the subgraph induced on the inbound edges of the given nodes. """Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming An in subgraph is equivalent to creating a new graph using the incoming
...@@ -291,20 +293,53 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -291,20 +293,53 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters Parameters
---------- ----------
nodes : torch.Tensor nodes: torch.Tensor or Dict[str, torch.Tensor]
The nodes to form the subgraph which are type agnostic. IDs of the given seed nodes.
- If `nodes` is a tensor: It means the graph is homogeneous
graph, and ids inside are homogeneous ids.
- If `nodes` is a dictionary: The keys should be node type and
ids inside are heterogeneous ids.
Returns Returns
------- -------
torch.classes.graphbolt.SampledSubgraph FusedSampledSubgraphImpl
The in subgraph. The in subgraph.
Examples
--------
>>> import dgl.graphbolt as gb
>>> import torch
>>> total_num_nodes = 5
>>> total_num_edges = 12
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
defaultdict(<class 'list'>, {
'N0:R0:N0': (tensor([]), tensor([])),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))}
""" """
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
# Ensure that there are no duplicate nodes. # Ensure that there are no duplicate nodes.
assert len(torch.unique(nodes)) == len( assert len(torch.unique(nodes)) == len(
nodes nodes
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes) _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_sampled_subgraph(_in_subgraph) return self._convert_to_sampled_subgraph(_in_subgraph)
......
...@@ -562,7 +562,10 @@ def test_in_subgraph_heterogeneous(): ...@@ -562,7 +562,10 @@ def test_in_subgraph_heterogeneous():
) )
# Extract in subgraph. # Extract in subgraph.
nodes = torch.LongTensor([1, 3, 4]) nodes = {
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([1, 2]),
}
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph. # Verify in subgraph.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment