Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
704aa423
Unverified
Commit
704aa423
authored
Jun 02, 2023
by
peizhou001
Committed by
GitHub
Jun 02, 2023
Browse files
[GraphBolt] Add replace for neighbor sampling (#5770)
[Graphbolt] Add replace for sampling
parent
c9c165f7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
17 deletions
+81
-17
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+19
-8
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+11
-6
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+12
-3
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+39
-0
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
704aa423
...
...
@@ -121,15 +121,20 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*
* @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
* times.Otherwise, each value can be selected only once.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
)
const
;
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
,
bool
replace
)
const
;
/**
* @brief Copy the graph to shared memory.
...
...
@@ -211,14 +216,20 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
* times.Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
*
* @return A tensor containing the picked neighbors.
*/
torch
::
Tensor
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
);
}
// namespace sampling
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
704aa423
...
...
@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
}
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
)
const
{
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
,
bool
replace
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
std
::
vector
<
torch
::
Tensor
>
picked_neighbors_per_node
(
num_nodes
);
...
...
@@ -148,7 +148,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
}
picked_neighbors_per_node
[
i
]
=
Pick
(
offset
,
num_neighbors
,
fanout
,
indptr_
.
options
());
Pick
(
offset
,
num_neighbors
,
fanout
,
replace
,
indptr_
.
options
());
num_picked_neighbors_per_node
[
i
+
1
]
=
picked_neighbors_per_node
[
i
].
size
(
0
);
}
...
...
@@ -197,14 +197,19 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
}
torch
::
Tensor
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
)
{
torch
::
Tensor
picked_neighbors
;
if
((
fanout
==
-
1
)
||
(
num_neighbors
<=
fanout
))
{
if
((
fanout
==
-
1
)
||
(
num_neighbors
<=
fanout
&&
!
replace
))
{
picked_neighbors
=
torch
::
arange
(
offset
,
offset
+
num_neighbors
,
options
);
}
else
{
picked_neighbors
=
torch
::
randperm
(
num_neighbors
)
+
offset
;
picked_neighbors
=
picked_neighbors
.
slice
(
0
,
0
,
fanout
);
if
(
replace
)
{
picked_neighbors
=
torch
::
randint
(
offset
,
offset
+
num_neighbors
,
{
fanout
},
options
);
}
else
{
picked_neighbors
=
torch
::
randperm
(
num_neighbors
,
options
)
+
offset
;
picked_neighbors
=
picked_neighbors
.
slice
(
0
,
0
,
fanout
);
}
}
return
picked_neighbors
;
}
...
...
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
704aa423
...
...
@@ -195,7 +195,10 @@ class CSCSamplingGraph:
return
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
def
sample_neighbors
(
self
,
nodes
:
torch
.
Tensor
,
fanout
:
int
self
,
nodes
:
torch
.
Tensor
,
fanout
:
int
,
replace
:
bool
=
False
,
)
->
torch
.
ScriptObject
:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
...
...
@@ -206,14 +209,20 @@ class CSCSamplingGraph:
IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, all neighbors will be selected.
>= 0 or -1. If -1 is given, it is equivalent to when the fanout
is greater or equal to the number of neighbors and replacement
is false, in which case all the neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
replace: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
"""
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
fanout
>=
0
or
fanout
==
-
1
,
"Fanout shoud have value >= 0 or -1"
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanout
)
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanout
,
replace
)
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
"""Copy the graph to shared memory.
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
704aa423
...
...
@@ -460,6 +460,41 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
assert
sampled_num
==
expected_sampled_num
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
@
pytest
.
mark
.
parametrize
(
"replace, expected_sampled_num"
,
[(
False
,
7
),
(
True
,
12
)]
)
def
test_sample_neighbors_replace
(
replace
,
expected_sampled_num
):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
num_nodes
=
5
num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
=
4
,
replace
=
replace
)
# Verify in subgraph.
sampled_num
=
subgraph
.
indices
.
size
(
0
)
assert
sampled_num
==
expected_sampled_num
def
check_tensors_on_the_same_shared_memory
(
t1
:
torch
.
Tensor
,
t2
:
torch
.
Tensor
):
"""Check if two tensors are on the same shared memory.
...
...
@@ -574,3 +609,7 @@ def test_hetero_graph_on_shared_memory(
assert
metadata
.
edge_type_to_id
==
graph1
.
metadata
.
edge_type_to_id
assert
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
if
__name__
==
"__main__"
:
test_sample_neighbors_replace
(
True
,
12
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment