Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ae97049e
Unverified
Commit
ae97049e
authored
Jun 04, 2023
by
peizhou001
Committed by
GitHub
Jun 04, 2023
Browse files
[GraphBolt] Add return_eid for neighbor sampling (#5772)
parent
5f490d19
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
8 deletions
+18
-8
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+3
-1
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+5
-4
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+6
-1
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+4
-2
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
ae97049e
...
@@ -135,13 +135,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -135,13 +135,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is preformed with or
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times.
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
*
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
* sampled graph's information.
*/
*/
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
)
const
;
bool
replace
,
bool
return_eids
)
const
;
/**
/**
* @brief Copy the graph to shared memory.
* @brief Copy the graph to shared memory.
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
ae97049e
...
@@ -123,7 +123,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
...
@@ -123,7 +123,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
)
const
{
bool
replace
,
bool
return_eids
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
// If true, perform sampling for each edge type of each node, otherwise just
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
// sample once for each node with no regard of edge types.
...
@@ -169,10 +169,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
...
@@ -169,10 +169,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch
::
Tensor
picked_eids
=
torch
::
cat
(
picked_neighbors_per_node
);
torch
::
Tensor
picked_eids
=
torch
::
cat
(
picked_neighbors_per_node
);
torch
::
Tensor
subgraph_indices
=
torch
::
Tensor
subgraph_indices
=
torch
::
index_select
(
indices_
,
0
,
picked_eids
);
torch
::
index_select
(
indices_
,
0
,
picked_eids
);
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
subgraph_indptr
,
subgraph_indices
,
nodes
,
torch
::
nullopt
,
torch
::
nullopt
,
subgraph_indptr
,
subgraph_indices
,
nodes
,
torch
::
nullopt
,
torch
::
nullopt
);
subgraph_reverse_edge_ids
,
torch
::
nullopt
);
}
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
...
...
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
ae97049e
...
@@ -199,6 +199,7 @@ class CSCSamplingGraph:
...
@@ -199,6 +199,7 @@ class CSCSamplingGraph:
nodes
:
torch
.
Tensor
,
nodes
:
torch
.
Tensor
,
fanouts
:
torch
.
Tensor
,
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
replace
:
bool
=
False
,
return_eids
:
bool
=
False
,
)
->
torch
.
ScriptObject
:
)
->
torch
.
ScriptObject
:
"""Sample neighboring edges of the given nodes and return the induced
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
subgraph.
...
@@ -227,6 +228,10 @@ class CSCSamplingGraph:
...
@@ -227,6 +228,10 @@ class CSCSamplingGraph:
Boolean indicating whether the sample is preformed with or
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
times. Otherwise, each value can be selected only once.
return_eids: bool
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required
"""
"""
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
...
@@ -241,7 +246,7 @@ class CSCSamplingGraph:
...
@@ -241,7 +246,7 @@ class CSCSamplingGraph:
),
"Fanouts should consist of values that are either -1 or
\
),
"Fanouts should consist of values that are either -1 or
\
greater than or equal to 0."
greater than or equal to 0."
return
self
.
_c_csc_graph
.
sample_neighbors
(
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanouts
.
tolist
(),
replace
nodes
,
fanouts
.
tolist
(),
replace
,
return_eids
)
)
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
ae97049e
...
@@ -412,7 +412,7 @@ def test_sample_neighbors():
...
@@ -412,7 +412,7 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
fanouts
=
torch
.
tensor
([
2
,
2
,
3
])
fanouts
=
torch
.
tensor
([
2
,
2
,
3
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanouts
)
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanouts
,
return_eids
=
True
)
# Verify in subgraph.
# Verify in subgraph.
assert
torch
.
equal
(
subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
assert
torch
.
equal
(
subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
...
@@ -421,8 +421,10 @@ def test_sample_neighbors():
...
@@ -421,8 +421,10 @@ def test_sample_neighbors():
torch
.
sort
(
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
]))[
0
],
torch
.
sort
(
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
]))[
0
],
)
)
assert
torch
.
equal
(
subgraph
.
reverse_column_node_ids
,
nodes
)
assert
torch
.
equal
(
subgraph
.
reverse_column_node_ids
,
nodes
)
assert
torch
.
equal
(
subgraph
.
reverse_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
assert
subgraph
.
reverse_row_node_ids
is
None
assert
subgraph
.
reverse_row_node_ids
is
None
assert
subgraph
.
reverse_edge_ids
is
None
assert
subgraph
.
type_per_edge
is
None
assert
subgraph
.
type_per_edge
is
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment