OpenDAS / dgl

Commit b3234516 (Unverified)
Authored Jun 03, 2023 by peizhou001
Committed by GitHub, Jun 03, 2023

[GraphBolt] Add sample etype neighbors (#5771)

Parent: 704aa423

Showing 4 changed files with 149 additions and 41 deletions
graphbolt/include/graphbolt/csc_sampling_graph.h  +52 -16
graphbolt/src/csc_sampling_graph.cc  +39 -4
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py  +31 -11
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py  +27 -10
graphbolt/include/graphbolt/csc_sampling_graph.h
@@ -120,21 +120,28 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
    * subgraph.
    *
    * @param nodes The nodes from which to sample neighbors.
-   * @param fanout The number of edges to be sampled for each node. It should be
-   * >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
-   * or equal to the number of neighbors and replacement is false, in which case
-   * all the neighbors will be selected. Otherwise, it will pick the minimum
-   * number of neighbors between the fanout value and the total number of
-   * neighbors.
+   * @param fanouts The number of edges to be sampled for each node with or
+   * without considering edge types.
+   * - When the length is 1, it indicates that the fanout applies to all
+   * neighbors of the node as a collective, regardless of the edge type.
+   * - Otherwise, the length should equal the number of edge types, and
+   * each fanout value corresponds to a specific edge type of the node.
+   * The value of each fanout should be >= 0 or = -1.
+   * - When the value is -1, all neighbors will be chosen for sampling. It is
+   * equivalent to selecting all neighbors when the fanout is >= the number of
+   * neighbors (and replacement is set to false).
+   * - When the value is a non-negative integer, it serves as a minimum
+   * threshold for selecting neighbors.
    * @param replace Boolean indicating whether the sample is performed with or
    * without replacement. If True, a value can be selected multiple times.
    * Otherwise, each value can be selected only once.
    *
    * @return An intrusive pointer to a SampledSubgraph object containing the
    * sampled graph's information.
    */
   c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
-      const torch::Tensor& nodes, int64_t fanout, bool replace) const;
+      const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+      bool replace) const;

   /**
    * @brief Copy the graph to shared memory.
...
@@ -216,14 +223,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
  * node.
  * @param num_neighbors The number of neighbors to pick.
  * @param fanout The number of edges to be sampled for each node. It should be
- * >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
- * or equal to the number of neighbors and replacement is false, in which case
- * all the neighbors will be selected. Otherwise, it will pick the minimum
- * number of neighbors between the fanout value and the total number of
- * neighbors.
+ * >= 0 or -1.
+ * - When the value is -1, all neighbors will be chosen for sampling. It is
+ * equivalent to selecting all neighbors when the fanout is >= the number of
+ * neighbors (and replacement is set to false).
+ * - When the value is a non-negative integer, it serves as a minimum
+ * threshold for selecting neighbors.
  * @param replace Boolean indicating whether the sample is performed with or
  * without replacement. If True, a value can be selected multiple times.
  * Otherwise, each value can be selected only once.
  * @param options Tensor options specifying the desired data type of the result.
  *
  * @return A tensor containing the picked neighbors.
...
@@ -232,6 +240,34 @@ torch::Tensor Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options);

+/**
+ * @brief Picks a specified number of neighbors for a node per edge type,
+ * starting from the given offset and having the specified number of neighbors.
+ *
+ * @param offset The starting edge ID for the connected neighbors of the sampled
+ * node.
+ * @param num_neighbors The number of neighbors to pick.
+ * @param fanouts The edge sampling numbers corresponding to each edge type for
+ * a single node. The value of each fanout should be >= 0 or = -1.
+ * - When the value is -1, all neighbors will be chosen for sampling. It is
+ * equivalent to selecting all neighbors when the fanout is >= the number of
+ * neighbors (and replacement is set to false).
+ * - When the value is a non-negative integer, it serves as a minimum threshold
+ * for selecting neighbors.
+ * @param replace Boolean indicating whether the sample is performed with or
+ * without replacement. If True, a value can be selected multiple times.
+ * Otherwise, each value can be selected only once.
+ * @param options Tensor options specifying the desired data type of the result.
+ * @param type_per_edge Tensor representing the type of each edge in the
+ * original graph.
+ *
+ * @return A tensor containing the picked neighbors.
+ */
+torch::Tensor PickByEtype(
+    int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
+    bool replace, const torch::TensorOptions& options,
+    const torch::Tensor& type_per_edge);
+
 }  // namespace sampling
 }  // namespace graphbolt
...
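The fanout rules in the comments above can be illustrated with a small sketch. It covers only sampling without replacement and follows the pre-change wording ("pick the minimum number of neighbors between the fanout value and the total number of neighbors"). The function name `picked_count` is hypothetical and is not part of GraphBolt.

```python
def picked_count(fanout: int, num_neighbors: int) -> int:
    """Number of neighbors picked for one node, without replacement.

    Hypothetical helper for illustration; it only mirrors the documented
    rules and does not call into GraphBolt.
    """
    if fanout < -1:
        raise ValueError("fanout must be >= 0 or -1")
    if fanout == -1:
        # -1 means "take every neighbor".
        return num_neighbors
    # A non-negative fanout cannot yield more picks than there are neighbors.
    return min(fanout, num_neighbors)
```

Under this reading, a fanout of -1 and any fanout at least as large as the neighbor count give the same result, which matches the equivalence the comments describe.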
graphbolt/src/csc_sampling_graph.cc
@@ -122,9 +122,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
 }

 c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
-    const torch::Tensor& nodes, int64_t fanout, bool replace) const {
+    const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+    bool replace) const {
   const int64_t num_nodes = nodes.size(0);
+  // If true, perform sampling for each edge type of each node, otherwise just
+  // sample once for each node with no regard of edge types.
+  bool consider_etype = (fanouts.size() > 1);
   std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
   torch::Tensor num_picked_neighbors_per_node =
       torch::zeros({num_nodes + 1}, indptr_.options());
...
@@ -147,8 +150,14 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
       continue;
     }
-    picked_neighbors_per_node[i] =
-        Pick(offset, num_neighbors, fanout, replace, indptr_.options());
+    if (consider_etype) {
+      picked_neighbors_per_node[i] = PickByEtype(
+          offset, num_neighbors, fanouts, replace, indptr_.options(),
+          type_per_edge_.value());
+    } else {
+      picked_neighbors_per_node[i] =
+          Pick(offset, num_neighbors, fanouts[0], replace, indptr_.options());
+    }
     num_picked_neighbors_per_node[i + 1] =
         picked_neighbors_per_node[i].size(0);
   }
...
@@ -214,5 +223,31 @@ torch::Tensor Pick(
   return picked_neighbors;
 }

+torch::Tensor PickByEtype(
+    int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
+    bool replace, const torch::TensorOptions& options,
+    const torch::Tensor& type_per_edge) {
+  std::vector<torch::Tensor> picked_neighbors(
+      fanouts.size(), torch::tensor({}, options));
+  int64_t etype_begin = offset;
+  int64_t etype_end = offset;
+  while (etype_end < offset + num_neighbors) {
+    int64_t etype = type_per_edge[etype_end].item<int64_t>();
+    int64_t fanout = fanouts[etype];
+    while (etype_end < offset + num_neighbors &&
+           type_per_edge[etype_end].item<int64_t>() == etype) {
+      etype_end++;
+    }
+    // Do sampling for one etype.
+    if (fanout != 0) {
+      picked_neighbors[etype] =
+          Pick(etype_begin, etype_end - etype_begin, fanout, replace, options);
+    }
+    etype_begin = etype_end;
+  }
+
+  return torch::cat(picked_neighbors, 0);
+}
+
 }  // namespace sampling
 }  // namespace graphbolt
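The run-scanning loop in `PickByEtype` can be mirrored in plain Python to make its control flow easier to follow. This is a minimal sketch, not the GraphBolt implementation: `pick` is a stand-in that assumes sampling without replacement, plain lists replace tensors, and the `rng` parameter is added for reproducibility. As in the C++ code, each run's result is assigned to `picked[etype]`, so the scan appears to rely on a node's edges being grouped by type; a type that showed up in two separate runs would have its first result overwritten.

```python
import random
from typing import List, Sequence


def pick(offset: int, num_neighbors: int, fanout: int,
         rng: random.Random) -> List[int]:
    """Stand-in for Pick (assumption): sample edge IDs without replacement."""
    edge_ids = range(offset, offset + num_neighbors)
    if fanout == -1 or fanout >= num_neighbors:
        return list(edge_ids)
    return sorted(rng.sample(edge_ids, fanout))


def pick_by_etype(offset: int, num_neighbors: int, fanouts: Sequence[int],
                  type_per_edge: Sequence[int],
                  rng: random.Random) -> List[int]:
    """Scan one node's edges as runs of equal edge type and sample each run."""
    picked: List[List[int]] = [[] for _ in fanouts]
    end = offset + num_neighbors
    etype_begin = etype_end = offset
    while etype_end < end:
        etype = type_per_edge[etype_end]
        fanout = fanouts[etype]
        # Advance to the end of the current run of identical edge types.
        while etype_end < end and type_per_edge[etype_end] == etype:
            etype_end += 1
        # A fanout of 0 skips this edge type entirely.
        if fanout != 0:
            picked[etype] = pick(
                etype_begin, etype_end - etype_begin, fanout, rng)
        etype_begin = etype_end
    # Concatenate per-type results in edge-type order, like torch::cat.
    return [edge for group in picked for edge in group]
```

For example, with the `type_per_edge` list from the tests, node 4 owns edges 9 through 11 with types 0, 0, 1, so fanouts of `[2, 2]` select all three edges.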
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
@@ -197,7 +197,7 @@ class CSCSamplingGraph:
     def sample_neighbors(
         self,
         nodes: torch.Tensor,
-        fanout: int,
+        fanouts: torch.Tensor,
         replace: bool = False,
     ) -> torch.ScriptObject:
         """Sample neighboring edges of the given nodes and return the induced
@@ -207,22 +207,42 @@ class CSCSamplingGraph:
...
@@ -207,22 +207,42 @@ class CSCSamplingGraph:
----------
----------
nodes: torch.Tensor
nodes: torch.Tensor
IDs of the given seed nodes.
IDs of the given seed nodes.
fanout: int
fanouts: torch.Tensor
The number of edges to be sampled for each node. It should be
The number of edges to be sampled for each node with or without
>= 0 or -1. If -1 is given, it is equivalent to when the fanout
considering edge types.
is greater or equal to the number of neighbors and replacement
- When the length is 1, it indicates that the fanout applies to
is false, in which case all the neighbors will be selected.
all neighbors of the node as a collective, regardless of the
Otherwise, it will pick the minimum number of neighbors between
edge type.
the fanout value and the total number of neighbors.
- Otherwise, the length should equal to the number of edge
replace: bool
types, and each fanout value corresponds to a specific edge
type of the nodes.
The value of each fanout should be >= 0 or = -1.
- When the value is -1, all neighbors will be chosen for
sampling. It is equivalent to selecting all neighbors when
the fanout is >= the number of neighbors (and replacement
is set to false).
- When the value is a non-negative integer, it serves as a
minimum threshold for selecting neighbors.
replce: bool
Boolean indicating whether the sample is preformed with or
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
times. Otherwise, each value can be selected only once.
"""
"""
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
fanout
>=
0
or
fanout
==
-
1
,
"Fanout shoud have value >= 0 or -1"
assert
fanouts
.
dim
()
==
1
,
"Fanouts should be 1-D tensor."
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanout
,
replace
)
if
fanouts
.
size
(
0
)
>
1
:
assert
(
self
.
type_per_edge
is
not
None
),
"To perform sampling for each edge type (when the length of
\
`fanouts` > 1), the graph must include edge type information."
assert
torch
.
all
(
(
fanouts
>=
0
)
|
(
fanouts
==
-
1
)
),
"Fanouts should consist of values that are either -1 or
\
greater than or equal to 0."
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanouts
.
tolist
(),
replace
)
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
def
copy_to_shared_memory
(
self
,
shared_memory_name
:
str
):
"""Copy the graph to shared memory.
"""Copy the graph to shared memory.
...
...
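The argument checks in `sample_neighbors` can be summarized in a standalone sketch that uses plain Python sequences instead of tensors. The function name `validate_fanouts` and the `has_edge_types` flag are hypothetical. The empty-sequence check is an extra assumption of this sketch: the Python checks visible in this commit do not appear to reject an empty `fanouts`.

```python
from typing import Sequence


def validate_fanouts(fanouts: Sequence[int], has_edge_types: bool) -> bool:
    """Validate fanouts and report whether per-edge-type sampling is requested.

    Hypothetical helper that mirrors the assertions in sample_neighbors.
    """
    # Extra check (assumption, not in the commit): need at least one value.
    if len(fanouts) == 0:
        raise ValueError("fanouts must contain at least one value")
    # More than one fanout means one fanout per edge type.
    per_etype = len(fanouts) > 1
    if per_etype and not has_edge_types:
        raise ValueError(
            "sampling per edge type requires edge type information")
    # Each value must be -1 (take all neighbors) or non-negative.
    if any(f < -1 for f in fanouts):
        raise ValueError("each fanout must be -1 or >= 0")
    return per_etype
```

Raising `ValueError` instead of using `assert` is a design choice of the sketch, since assertions are skipped when Python runs with `-O`.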
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
@@ -402,21 +402,23 @@ def test_sample_neighbors():
     num_edges = 12
     indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
     indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
+    type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
     assert indptr[-1] == num_edges
     assert indptr[-1] == len(indices)

     # Construct CSCSamplingGraph.
-    graph = gb.from_csc(indptr, indices)
+    graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)

     # Generate subgraph via sample neighbors.
     nodes = torch.LongTensor([1, 3, 4])
-    fanout = -1
-    subgraph = graph.sample_neighbors(nodes, fanout)
+    fanouts = torch.tensor([2, 2, 3])
+    subgraph = graph.sample_neighbors(nodes, fanouts)

     # Verify in subgraph.
     assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
     assert torch.equal(
-        subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
+        torch.sort(subgraph.indices)[0],
+        torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0],
     )
     assert torch.equal(subgraph.reverse_column_node_ids, nodes)
     assert subgraph.reverse_row_node_ids is None
...
@@ -429,10 +431,21 @@ def test_sample_neighbors():
     reason="Graph is CPU only at present.",
 )
 @pytest.mark.parametrize(
-    "fanout, expected_sampled_num",
-    [(0, 0), (1, 3), (2, 6), (3, 7), (4, 7), (-1, 7)],
+    "fanouts, expected_sampled_num",
+    [
+        ([0], 0),
+        ([1], 3),
+        ([2], 6),
+        ([4], 7),
+        ([-1], 7),
+        ([0, 0], 0),
+        ([1, 0], 3),
+        ([1, 1], 6),
+        ([2, 2], 7),
+        ([-1, -1], 7),
+    ],
 )
-def test_sample_neighbors_fanout(fanout, expected_sampled_num):
+def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
     """Original graph in COO:
     1   0   1   0   1
     1   0   1   1   0
...
@@ -445,15 +458,17 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
     num_edges = 12
     indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
     indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
+    type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
     assert indptr[-1] == num_edges
     assert indptr[-1] == len(indices)

     # Construct CSCSamplingGraph.
-    graph = gb.from_csc(indptr, indices)
+    graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)

     # Generate subgraph via sample neighbors.
     nodes = torch.LongTensor([1, 3, 4])
-    subgraph = graph.sample_neighbors(nodes, fanout)
+    fanouts = torch.LongTensor(fanouts)
+    subgraph = graph.sample_neighbors(nodes, fanouts)

     # Verify in subgraph.
     sampled_num = subgraph.indices.size(0)
...
@@ -488,7 +503,9 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
     # Generate subgraph via sample neighbors.
     nodes = torch.LongTensor([1, 3, 4])
-    subgraph = graph.sample_neighbors(nodes, fanout=4, replace=replace)
+    subgraph = graph.sample_neighbors(
+        nodes, fanouts=torch.LongTensor([4]), replace=replace
+    )

     # Verify in subgraph.
     sampled_num = subgraph.indices.size(0)
...
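The `expected_sampled_num` values in the parametrized fanouts test can be recomputed by hand from the CSC arrays. The sketch below does this in plain Python for seed nodes 1, 3, and 4, assuming sampling without replacement (the test relies on the default `replace=False`). The helper `expected_sampled_num` and the constants are illustrative names, not part of the test file.

```python
from collections import Counter
from typing import Sequence

# CSC arrays copied from the tests above.
INDPTR = [0, 3, 5, 7, 9, 12]
TYPE_PER_EDGE = [0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1]


def expected_sampled_num(nodes: Sequence[int], fanouts: Sequence[int]) -> int:
    """Count the edges sampled for the seed nodes, without replacement."""
    total = 0
    for node in nodes:
        begin, end = INDPTR[node], INDPTR[node + 1]
        if len(fanouts) == 1:
            # One fanout: all neighbors form a single collective group.
            groups = {0: end - begin}
        else:
            # One fanout per edge type: group the neighbors by type.
            groups = Counter(TYPE_PER_EDGE[begin:end])
        for key, num_neighbors in groups.items():
            fanout = fanouts[key]
            total += (
                num_neighbors if fanout == -1 else min(fanout, num_neighbors))
    return total


# (fanouts, expected_sampled_num) pairs from the parametrized test.
CASES = [
    ([0], 0), ([1], 3), ([2], 6), ([4], 7), ([-1], 7),
    ([0, 0], 0), ([1, 0], 3), ([1, 1], 6), ([2, 2], 7), ([-1, -1], 7),
]
```

For instance, with fanouts `[1, 0]` each seed node contributes one type-0 edge and no type-1 edges, giving 3 sampled edges in total.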