Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c9c165f7
Unverified
Commit
c9c165f7
authored
Jun 02, 2023
by
peizhou001
Committed by
GitHub
Jun 02, 2023
Browse files
[GraphBolt] Add fanout for neighbor sampling (#5768)
[Graphbolt] Add fanout for sampling
parent
a99095e7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
88 additions
and
7 deletions
+88
-7
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+23
-1
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+17
-3
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+10
-2
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+38
-1
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
c9c165f7
...
@@ -120,12 +120,16 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -120,12 +120,16 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* subgraph.
* subgraph.
*
*
* @param nodes The nodes from which to sample neighbors.
* @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
*
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
* sampled graph's information.
*/
*/
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
)
const
;
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
)
const
;
/**
/**
* @brief Copy the graph to shared memory.
* @brief Copy the graph to shared memory.
...
@@ -199,6 +203,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -199,6 +203,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr
tensor_meta_shm_
,
tensor_data_shm_
;
SharedMemoryPtr
tensor_meta_shm_
,
tensor_data_shm_
;
};
};
/**
* @brief Picks a specified number of neighbors for a node, starting from the
* given offset and having the specified number of neighbors.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* @param options Tensor options specifying the desired data type of the result.
* @return A tensor containing the picked neighbors.
*/
torch
::
Tensor
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
const
torch
::
TensorOptions
&
options
);
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
c9c165f7
...
@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
...
@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
}
}
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
)
const
{
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
std
::
vector
<
torch
::
Tensor
>
picked_neighbors_per_node
(
num_nodes
);
std
::
vector
<
torch
::
Tensor
>
picked_neighbors_per_node
(
num_nodes
);
...
@@ -148,8 +148,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
...
@@ -148,8 +148,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
}
}
picked_neighbors_per_node
[
i
]
=
picked_neighbors_per_node
[
i
]
=
torch
::
arange
(
offset
,
offset
+
num_neighbors
);
Pick
(
offset
,
num_neighbors
,
fanout
,
indptr_
.
options
());
num_picked_neighbors_per_node
[
i
+
1
]
=
num_neighbors
;
num_picked_neighbors_per_node
[
i
+
1
]
=
picked_neighbors_per_node
[
i
].
size
(
0
);
}
}
});
// End of the thread.
});
// End of the thread.
...
@@ -195,5 +196,18 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
...
@@ -195,5 +196,18 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
return
BuildGraphFromSharedMemoryTensors
(
std
::
move
(
shared_memory_tensors
));
return
BuildGraphFromSharedMemoryTensors
(
std
::
move
(
shared_memory_tensors
));
}
}
torch
::
Tensor
Pick
(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
const
torch
::
TensorOptions
&
options
)
{
torch
::
Tensor
picked_neighbors
;
if
((
fanout
==
-
1
)
||
(
num_neighbors
<=
fanout
))
{
picked_neighbors
=
torch
::
arange
(
offset
,
offset
+
num_neighbors
,
options
);
}
else
{
picked_neighbors
=
torch
::
randperm
(
num_neighbors
)
+
offset
;
picked_neighbors
=
picked_neighbors
.
slice
(
0
,
0
,
fanout
);
}
return
picked_neighbors
;
}
}
// namespace sampling
}
// namespace sampling
}
// namespace graphbolt
}
// namespace graphbolt
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
c9c165f7
...
@@ -194,7 +194,9 @@ class CSCSamplingGraph:
...
@@ -194,7 +194,9 @@ class CSCSamplingGraph:
),
"Nodes cannot have duplicate values."
),
"Nodes cannot have duplicate values."
return
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
return
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
def
sample_neighbors
(
self
,
nodes
:
torch
.
Tensor
)
->
torch
.
ScriptObject
:
def
sample_neighbors
(
self
,
nodes
:
torch
.
Tensor
,
fanout
:
int
)
->
torch
.
ScriptObject
:
"""Sample neighboring edges of the given nodes and return the induced
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
subgraph.
...
@@ -202,10 +204,16 @@ class CSCSamplingGraph:
...
@@ -202,10 +204,16 @@ class CSCSamplingGraph:
----------
----------
nodes: torch.Tensor
nodes: torch.Tensor
IDs of the given seed nodes.
IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, all neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
"""
"""
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
)
assert
fanout
>=
0
or
fanout
==
-
1
,
"Fanout shoud have value >= 0 or -1"
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanout
)
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
"""Copy the graph to shared memory.
"""Copy the graph to shared memory.
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
c9c165f7
...
@@ -410,7 +410,8 @@ def test_sample_neighbors():
...
@@ -410,7 +410,8 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
subgraph
=
graph
.
sample_neighbors
(
nodes
)
fanout
=
-
1
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
)
# Verify in subgraph.
# Verify in subgraph.
assert
torch
.
equal
(
subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
assert
torch
.
equal
(
subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
...
@@ -423,6 +424,42 @@ def test_sample_neighbors():
...
@@ -423,6 +424,42 @@ def test_sample_neighbors():
assert
subgraph
.
type_per_edge
is
None
assert
subgraph
.
type_per_edge
is
None
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
@
pytest
.
mark
.
parametrize
(
"fanout, expected_sampled_num"
,
[(
0
,
0
),
(
1
,
3
),
(
2
,
6
),
(
3
,
7
),
(
4
,
7
),
(
-
1
,
7
)],
)
def
test_sample_neighbors_fanout
(
fanout
,
expected_sampled_num
):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
num_nodes
=
5
num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
)
# Verify in subgraph.
sampled_num
=
subgraph
.
indices
.
size
(
0
)
assert
sampled_num
==
expected_sampled_num
def
check_tensors_on_the_same_shared_memory
(
t1
:
torch
.
Tensor
,
t2
:
torch
.
Tensor
):
def
check_tensors_on_the_same_shared_memory
(
t1
:
torch
.
Tensor
,
t2
:
torch
.
Tensor
):
"""Check if two tensors are on the same shared memory.
"""Check if two tensors are on the same shared memory.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment