Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b35757a0
Unverified
Commit
b35757a0
authored
Nov 09, 2023
by
Mingbang Wang
Committed by
GitHub
Nov 09, 2023
Browse files
[GraphBolt] Enable `FusedCSCSamplingGraph.in_subgraph()` to accept dict (#6550)
parent
2d689235
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
5 deletions
+43
-5
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+39
-4
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+4
-1
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
b35757a0
...
...
@@ -283,7 +283,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
return
self
.
_metadata
def
in_subgraph
(
self
,
nodes
:
torch
.
Tensor
)
->
torch
.
ScriptObject
:
def
in_subgraph
(
self
,
nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
)
->
FusedSampledSubgraphImpl
:
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
...
...
@@ -291,20 +293,53 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters
----------
nodes : torch.Tensor
The nodes to form the subgraph which are type agnostic.
nodes: torch.Tensor or Dict[str, torch.Tensor]
IDs of the given seed nodes.
- If `nodes` is a tensor: It means the graph is homogeneous
graph, and ids inside are homogeneous ids.
- If `nodes` is a dictionary: The keys should be node type and
ids inside are heterogeneous ids.
Returns
-------
torch.classes.graphbolt.
SampledSubgraph
Fused
SampledSubgraph
Impl
The in subgraph.
Examples
--------
>>> import dgl.graphbolt as gb
>>> import torch
>>> total_num_nodes = 5
>>> total_num_edges = 12
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
defaultdict(<class 'list'>, {
'N0:R0:N0': (tensor([]), tensor([])),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))}
"""
if
isinstance
(
nodes
,
dict
):
nodes
=
self
.
_convert_to_homogeneous_nodes
(
nodes
)
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
# Ensure that there are no duplicate nodes.
assert
len
(
torch
.
unique
(
nodes
))
==
len
(
nodes
),
"Nodes cannot have duplicate values."
_in_subgraph
=
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
return
self
.
_convert_to_sampled_subgraph
(
_in_subgraph
)
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
b35757a0
...
...
@@ -562,7 +562,10 @@ def test_in_subgraph_heterogeneous():
)
# Extract in subgraph.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
{
"N0"
:
torch
.
LongTensor
([
1
]),
"N1"
:
torch
.
LongTensor
([
1
,
2
]),
}
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment