Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
236ffa0f
Unverified
Commit
236ffa0f
authored
May 26, 2023
by
Rhett Ying
Committed by
GitHub
May 26, 2023
Browse files
[GraphBolt] add in_subgraph() API at the python level (#5743)
parent
2cb7c69d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
137 additions
and
2 deletions
+137
-2
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+1
-1
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+2
-1
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+24
-0
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+110
-0
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
236ffa0f
...
@@ -108,7 +108,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
...
@@ -108,7 +108,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
torch
::
zeros
({
nonzero_idx
.
size
(
0
)
+
1
},
indptr_
.
dtype
());
torch
::
zeros
({
nonzero_idx
.
size
(
0
)
+
1
},
indptr_
.
dtype
());
compact_indptr
.
index_put_
({
Slice
(
1
,
None
)},
indptr
.
index
({
nonzero_idx
}));
compact_indptr
.
index_put_
({
Slice
(
1
,
None
)},
indptr
.
index
({
nonzero_idx
}));
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
compact_indptr
.
cumsum
(
0
),
torch
::
cat
(
indices_arr
),
nonzero_idx
,
compact_indptr
.
cumsum
(
0
),
torch
::
cat
(
indices_arr
),
nonzero_idx
-
1
,
torch
::
arange
(
0
,
NumNodes
()),
torch
::
cat
(
edge_ids_arr
),
torch
::
arange
(
0
,
NumNodes
()),
torch
::
cat
(
edge_ids_arr
),
type_per_edge_
type_per_edge_
?
torch
::
optional
<
torch
::
Tensor
>
{
torch
::
cat
(
type_per_edge_arr
)}
?
torch
::
optional
<
torch
::
Tensor
>
{
torch
::
cat
(
type_per_edge_arr
)}
...
...
graphbolt/src/python_binding.cc
View file @
236ffa0f
...
@@ -27,7 +27,8 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -27,7 +27,8 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"csc_indptr"
,
&
CSCSamplingGraph
::
CSCIndptr
)
.
def
(
"csc_indptr"
,
&
CSCSamplingGraph
::
CSCIndptr
)
.
def
(
"indices"
,
&
CSCSamplingGraph
::
Indices
)
.
def
(
"indices"
,
&
CSCSamplingGraph
::
Indices
)
.
def
(
"node_type_offset"
,
&
CSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"node_type_offset"
,
&
CSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"type_per_edge"
,
&
CSCSamplingGraph
::
TypePerEdge
);
.
def
(
"type_per_edge"
,
&
CSCSamplingGraph
::
TypePerEdge
)
.
def
(
"in_subgraph"
,
&
CSCSamplingGraph
::
InSubgraph
);
m
.
def
(
"from_csc"
,
&
CSCSamplingGraph
::
FromCSC
);
m
.
def
(
"from_csc"
,
&
CSCSamplingGraph
::
FromCSC
);
m
.
def
(
"load_csc_sampling_graph"
,
&
LoadCSCSamplingGraph
);
m
.
def
(
"load_csc_sampling_graph"
,
&
LoadCSCSamplingGraph
);
m
.
def
(
"save_csc_sampling_graph"
,
&
SaveCSCSamplingGraph
);
m
.
def
(
"save_csc_sampling_graph"
,
&
SaveCSCSamplingGraph
);
...
...
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
236ffa0f
...
@@ -173,6 +173,30 @@ class CSCSamplingGraph:
...
@@ -173,6 +173,30 @@ class CSCSamplingGraph:
"""
"""
return
self
.
_metadata
return
self
.
_metadata
def
in_subgraph
(
self
,
nodes
:
torch
.
Tensor
)
->
torch
.
ScriptObject
:
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
edges of the given nodes.
Parameters
----------
nodes : torch.Tensor
The nodes to form the subgraph which are type agnostic.
Returns
-------
SampledSubgraph
The in subgraph.
"""
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
# Ensure that there are no duplicate nodes.
assert
len
(
torch
.
unique
(
nodes
))
==
len
(
nodes
),
"Nodes cannot have duplicate values."
return
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
def
from_csc
(
def
from_csc
(
csc_indptr
:
torch
.
Tensor
,
csc_indptr
:
torch
.
Tensor
,
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
236ffa0f
...
@@ -274,3 +274,113 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
...
@@ -274,3 +274,113 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
graph
.
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
graph
.
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
graph
.
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
assert
graph
.
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
def
test_in_subgraph_homogeneous
():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
num_nodes
=
5
num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
# Extract in subgraph.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
assert
torch
.
equal
(
in_subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
assert
torch
.
equal
(
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
)
assert
torch
.
equal
(
in_subgraph
.
reverse_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
reverse_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
)
assert
torch
.
equal
(
in_subgraph
.
reverse_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
assert
in_subgraph
.
type_per_edge
is
None
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
def
test_in_subgraph_heterogeneous
():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
node_type_0: [0, 1]
node_type_1: [2, 3, 4]
edge_type_0: node_type_0 -> node_type_0
edge_type_1: node_type_0 -> node_type_1
edge_type_2: node_type_1 -> node_type_0
edge_type_3: node_type_1 -> node_type_1
"""
# Initialize data.
num_nodes
=
5
num_edges
=
12
ntypes
=
{
"N0"
:
0
,
"N1"
:
1
,
}
etypes
=
{
(
"N0"
,
"R0"
,
"N0"
):
0
,
(
"N0"
,
"R1"
,
"N1"
):
1
,
(
"N1"
,
"R2"
,
"N0"
):
2
,
(
"N1"
,
"R3"
,
"N1"
):
3
,
}
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
node_type_offset
=
torch
.
LongTensor
([
0
,
2
,
5
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
2
,
2
,
2
,
1
,
1
,
1
,
3
,
1
,
3
,
3
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
node_type_offset
[
-
1
]
==
num_nodes
assert
all
(
type_per_edge
<
len
(
etypes
))
# Construct CSCSamplingGraph.
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
# Extract in subgraph.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
assert
torch
.
equal
(
in_subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
assert
torch
.
equal
(
in_subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
)
assert
torch
.
equal
(
in_subgraph
.
reverse_column_node_ids
,
nodes
)
assert
torch
.
equal
(
in_subgraph
.
reverse_row_node_ids
,
torch
.
arange
(
0
,
num_nodes
)
)
assert
torch
.
equal
(
in_subgraph
.
reverse_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
assert
torch
.
equal
(
in_subgraph
.
type_per_edge
,
torch
.
LongTensor
([
2
,
2
,
1
,
3
,
1
,
3
,
3
])
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment