Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b3234516
Unverified
Commit
b3234516
authored
Jun 03, 2023
by
peizhou001
Committed by
GitHub
Jun 03, 2023
Browse files
[GraphBolt] Add sample etype neighbors (#5771)
parent
704aa423
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
149 additions
and
41 deletions
+149
-41
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+52
-16
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+39
-4
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+31
-11
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+27
-10
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
b3234516
...
...
@@ -120,21 +120,28 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* subgraph.
*
* @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all
* neighbors of the node as a collective, regardless of the edge type.
* - Otherwise, the length should equal to the number of edge types, and
* each fanout value corresponds to a specific edge type of the node.
* The value of each fanout should be >= 0 or = -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
*
times.
Otherwise, each value can be selected only once.
* without replacement. If True, a value can be selected multiple
times.
* Otherwise, each value can be selected only once.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
,
bool
replace
)
const
;
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
)
const
;
/**
* @brief Copy the graph to shared memory.
...
...
@@ -216,14 +223,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* >= 0 or -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
*
times.
Otherwise, each value can be selected only once.
* without replacement. If True, a value can be selected multiple
times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
*
* @return A tensor containing the picked neighbors.
...
...
@@ -232,6 +240,34 @@ torch::Tensor Pick(
int64_t
offset
,
int64_t
num_neighbors
,
int64_t
fanout
,
bool
replace
,
const
torch
::
TensorOptions
&
options
);
/**
* @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanouts The edge sampling numbers corresponding to each edge type for
* a single node. The value of each fanout should be >= 0 or = 1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum threshold
* for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
* @param type_per_edge Tensor representing the type of each edge in the
* original graph.
*
* @return A tensor containing the picked neighbors.
*/
torch
::
Tensor
PickByEtype
(
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
);
}
// namespace sampling
}
// namespace graphbolt
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
b3234516
...
...
@@ -122,9 +122,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
}
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
int64_t
fanout
,
bool
replace
)
const
{
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool
consider_etype
=
(
fanouts
.
size
()
>
1
);
std
::
vector
<
torch
::
Tensor
>
picked_neighbors_per_node
(
num_nodes
);
torch
::
Tensor
num_picked_neighbors_per_node
=
torch
::
zeros
({
num_nodes
+
1
},
indptr_
.
options
());
...
...
@@ -147,8 +150,14 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
continue
;
}
picked_neighbors_per_node
[
i
]
=
Pick
(
offset
,
num_neighbors
,
fanout
,
replace
,
indptr_
.
options
());
if
(
consider_etype
)
{
picked_neighbors_per_node
[
i
]
=
PickByEtype
(
offset
,
num_neighbors
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
.
value
());
}
else
{
picked_neighbors_per_node
[
i
]
=
Pick
(
offset
,
num_neighbors
,
fanouts
[
0
],
replace
,
indptr_
.
options
());
}
num_picked_neighbors_per_node
[
i
+
1
]
=
picked_neighbors_per_node
[
i
].
size
(
0
);
}
...
...
@@ -214,5 +223,31 @@ torch::Tensor Pick(
return
picked_neighbors
;
}
torch
::
Tensor
PickByEtype
(
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
)
{
std
::
vector
<
torch
::
Tensor
>
picked_neighbors
(
fanouts
.
size
(),
torch
::
tensor
({},
options
));
int64_t
etype_begin
=
offset
;
int64_t
etype_end
=
offset
;
while
(
etype_end
<
offset
+
num_neighbors
)
{
int64_t
etype
=
type_per_edge
[
etype_end
].
item
<
int64_t
>
();
int64_t
fanout
=
fanouts
[
etype
];
while
(
etype_end
<
offset
+
num_neighbors
&&
type_per_edge
[
etype_end
].
item
<
int64_t
>
()
==
etype
)
{
etype_end
++
;
}
// Do sampling for one etype.
if
(
fanout
!=
0
)
{
picked_neighbors
[
etype
]
=
Pick
(
etype_begin
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
);
}
etype_begin
=
etype_end
;
}
return
torch
::
cat
(
picked_neighbors
,
0
);
}
}
// namespace sampling
}
// namespace graphbolt
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
b3234516
...
...
@@ -197,7 +197,7 @@ class CSCSamplingGraph:
def
sample_neighbors
(
self
,
nodes
:
torch
.
Tensor
,
fanout
:
int
,
fanout
s
:
torch
.
Tensor
,
replace
:
bool
=
False
,
)
->
torch
.
ScriptObject
:
"""Sample neighboring edges of the given nodes and return the induced
...
...
@@ -207,22 +207,42 @@ class CSCSamplingGraph:
----------
nodes: torch.Tensor
IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, it is equivalent to when the fanout
is greater or equal to the number of neighbors and replacement
is false, in which case all the neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
replace: bool
fanouts: torch.Tensor
The number of edges to be sampled for each node with or without
considering edge types.
- When the length is 1, it indicates that the fanout applies to
all neighbors of the node as a collective, regardless of the
edge type.
- Otherwise, the length should equal to the number of edge
types, and each fanout value corresponds to a specific edge
type of the nodes.
The value of each fanout should be >= 0 or = -1.
- When the value is -1, all neighbors will be chosen for
sampling. It is equivalent to selecting all neighbors when
the fanout is >= the number of neighbors (and replacement
is set to false).
- When the value is a non-negative integer, it serves as a
minimum threshold for selecting neighbors.
replce: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
"""
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
fanout
>=
0
or
fanout
==
-
1
,
"Fanout shoud have value >= 0 or -1"
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanout
,
replace
)
assert
fanouts
.
dim
()
==
1
,
"Fanouts should be 1-D tensor."
if
fanouts
.
size
(
0
)
>
1
:
assert
(
self
.
type_per_edge
is
not
None
),
"To perform sampling for each edge type (when the length of
\
`fanouts` > 1), the graph must include edge type information."
assert
torch
.
all
(
(
fanouts
>=
0
)
|
(
fanouts
==
-
1
)
),
"Fanouts should consist of values that are either -1 or
\
greater than or equal to 0."
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanouts
.
tolist
(),
replace
)
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
"""Copy the graph to shared memory.
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
b3234516
...
...
@@ -402,21 +402,23 @@ def test_sample_neighbors():
num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
)
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
fanout
=
-
1
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
)
fanout
s
=
torch
.
tensor
([
2
,
2
,
3
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
s
)
# Verify in subgraph.
assert
torch
.
equal
(
subgraph
.
indptr
,
torch
.
LongTensor
([
0
,
2
,
4
,
7
]))
assert
torch
.
equal
(
subgraph
.
indices
,
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
])
torch
.
sort
(
subgraph
.
indices
)[
0
],
torch
.
sort
(
torch
.
LongTensor
([
2
,
3
,
1
,
2
,
0
,
3
,
4
]))[
0
],
)
assert
torch
.
equal
(
subgraph
.
reverse_column_node_ids
,
nodes
)
assert
subgraph
.
reverse_row_node_ids
is
None
...
...
@@ -429,10 +431,21 @@ def test_sample_neighbors():
reason
=
"Graph is CPU only at present."
,
)
@
pytest
.
mark
.
parametrize
(
"fanout, expected_sampled_num"
,
[(
0
,
0
),
(
1
,
3
),
(
2
,
6
),
(
3
,
7
),
(
4
,
7
),
(
-
1
,
7
)],
"fanouts, expected_sampled_num"
,
[
([
0
],
0
),
([
1
],
3
),
([
2
],
6
),
([
4
],
7
),
([
-
1
],
7
),
([
0
,
0
],
0
),
([
1
,
0
],
3
),
([
1
,
1
],
6
),
([
2
,
2
],
7
),
([
-
1
,
-
1
],
7
),
],
)
def
test_sample_neighbors_fanout
(
fanout
,
expected_sampled_num
):
def
test_sample_neighbors_fanout
s
(
fanout
s
,
expected_sampled_num
):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
...
...
@@ -445,15 +458,17 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
)
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
)
fanouts
=
torch
.
LongTensor
(
fanouts
)
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanouts
)
# Verify in subgraph.
sampled_num
=
subgraph
.
indices
.
size
(
0
)
...
...
@@ -488,7 +503,9 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanout
=
4
,
replace
=
replace
)
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanouts
=
torch
.
LongTensor
([
4
]),
replace
=
replace
)
# Verify in subgraph.
sampled_num
=
subgraph
.
indices
.
size
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment