Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b35757a0
Unverified
Commit
b35757a0
authored
Nov 09, 2023
by
Mingbang Wang
Committed by
GitHub
Nov 09, 2023
Browse files
[GraphBolt] Enable `FusedCSCSamplingGraph.in_subgraph()` to accept dict (#6550)
parent
2d689235
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
5 deletions
+43
-5
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+39
-4
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+4
-1
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
b35757a0
...
@@ -283,7 +283,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -283,7 +283,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
"""
return
self
.
_metadata
return
self
.
_metadata
def
in_subgraph
(
self
,
nodes
:
torch
.
Tensor
)
->
torch
.
ScriptObject
:
def
in_subgraph
(
self
,
nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
)
->
FusedSampledSubgraphImpl
:
"""Return the subgraph induced on the inbound edges of the given nodes.
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
An in subgraph is equivalent to creating a new graph using the incoming
...
@@ -291,20 +293,53 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -291,20 +293,53 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters
Parameters
----------
----------
nodes : torch.Tensor
nodes: torch.Tensor or Dict[str, torch.Tensor]
The nodes to form the subgraph which are type agnostic.
IDs of the given seed nodes.
- If `nodes` is a tensor: It means the graph is homogeneous
graph, and ids inside are homogeneous ids.
- If `nodes` is a dictionary: The keys should be node type and
ids inside are heterogeneous ids.
Returns
Returns
-------
-------
torch.classes.graphbolt.
SampledSubgraph
Fused
SampledSubgraph
Impl
The in subgraph.
The in subgraph.
Examples
--------
>>> import dgl.graphbolt as gb
>>> import torch
>>> total_num_nodes = 5
>>> total_num_edges = 12
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
defaultdict(<class 'list'>, {
'N0:R0:N0': (tensor([]), tensor([])),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])),
'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))}
"""
"""
if
isinstance
(
nodes
,
dict
):
nodes
=
self
.
_convert_to_homogeneous_nodes
(
nodes
)
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
# Ensure that there are no duplicate nodes.
# Ensure that there are no duplicate nodes.
assert
len
(
torch
.
unique
(
nodes
))
==
len
(
assert
len
(
torch
.
unique
(
nodes
))
==
len
(
nodes
nodes
),
"Nodes cannot have duplicate values."
),
"Nodes cannot have duplicate values."
_in_subgraph
=
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
_in_subgraph
=
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
return
self
.
_convert_to_sampled_subgraph
(
_in_subgraph
)
return
self
.
_convert_to_sampled_subgraph
(
_in_subgraph
)
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
b35757a0
...
@@ -562,7 +562,10 @@ def test_in_subgraph_heterogeneous():
...
@@ -562,7 +562,10 @@ def test_in_subgraph_heterogeneous():
)
)
# Extract in subgraph.
# Extract in subgraph.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
{
"N0"
:
torch
.
LongTensor
([
1
]),
"N1"
:
torch
.
LongTensor
([
1
,
2
]),
}
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
# Verify in subgraph.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment