Unverified commit b35757a0, authored by Mingbang Wang, committed by GitHub
Browse files

[GraphBolt] Enable `FusedCSCSamplingGraph.in_subgraph()` to accept dict (#6550)

parent 2d689235
......@@ -283,7 +283,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
return self._metadata
def in_subgraph(self, nodes: torch.Tensor) -> torch.ScriptObject:
def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl:
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
......@@ -291,20 +293,53 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters
----------
nodes : torch.Tensor
The nodes to form the subgraph which are type agnostic.
nodes: torch.Tensor or Dict[str, torch.Tensor]
IDs of the given seed nodes.
- If `nodes` is a tensor: the graph is a homogeneous graph, and
the ids inside are homogeneous ids.
- If `nodes` is a dictionary: the keys should be node types, and
the ids inside are heterogeneous ids.
Returns
-------
torch.classes.graphbolt.SampledSubgraph
FusedSampledSubgraphImpl
The in subgraph.
Examples
--------
>>> import dgl.graphbolt as gb
>>> import torch
>>> total_num_nodes = 5
>>> total_num_edges = 12
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
defaultdict(<class 'list'>, {
'N0:R0:N0': (tensor([]), tensor([])),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))})
"""
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
# Ensure that there are no duplicate nodes.
assert len(torch.unique(nodes)) == len(
nodes
), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_sampled_subgraph(_in_subgraph)
......
......@@ -562,7 +562,10 @@ def test_in_subgraph_heterogeneous():
)
# Extract in subgraph.
nodes = torch.LongTensor([1, 3, 4])
nodes = {
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([1, 2]),
}
in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment