OpenDAS / dgl
Commit b3234516 (unverified)
Authored Jun 03, 2023 by peizhou001; committed by GitHub on Jun 03, 2023
[GraphBolt] Add sample etype neighbors (#5771)
Parent: 704aa423
Showing 4 changed files with 149 additions and 41 deletions
graphbolt/include/graphbolt/csc_sampling_graph.h (+52 -16)
graphbolt/src/csc_sampling_graph.cc (+39 -4)
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py (+31 -11)
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py (+27 -10)
graphbolt/include/graphbolt/csc_sampling_graph.h
@@ -120,21 +120,28 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
    * subgraph.
    *
    * @param nodes The nodes from which to sample neighbors.
-   * @param fanout The number of edges to be sampled for each node. It should be
-   * >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
-   * or equal to the number of neighbors and replacement is false, in which case
-   * all the neighbors will be selected. Otherwise, it will pick the minimum
-   * number of neighbors between the fanout value and the total number of
-   * neighbors.
+   * @param fanouts The number of edges to be sampled for each node with or
+   * without considering edge types.
+   *   - When the length is 1, it indicates that the fanout applies to all
+   *     neighbors of the node as a collective, regardless of the edge type.
+   *   - Otherwise, the length should be equal to the number of edge types, and
+   *     each fanout value corresponds to a specific edge type of the node.
+   * The value of each fanout should be >= 0 or = -1.
+   *   - When the value is -1, all neighbors will be chosen for sampling. It is
+   *     equivalent to selecting all neighbors when the fanout is >= the number of
+   *     neighbors (and replacement is set to false).
+   *   - When the value is a non-negative integer, it serves as an upper bound
+   *     on the number of neighbors selected.
    * @param replace Boolean indicating whether the sample is performed with or
-   * without replacement. If True, a value can be selected multiple
-   * times. Otherwise, each value can be selected only once.
+   * without replacement. If True, a value can be selected multiple times.
+   * Otherwise, each value can be selected only once.
    *
    * @return An intrusive pointer to a SampledSubgraph object containing the
    * sampled graph's information.
    */
   c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
-      const torch::Tensor& nodes, int64_t fanout, bool replace) const;
+      const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+      bool replace) const;

   /**
    * @brief Copy the graph to shared memory.
@@ -216,14 +223,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
  * node.
  * @param num_neighbors The number of neighbors to pick.
  * @param fanout The number of edges to be sampled for each node. It should be
- * >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
- * or equal to the number of neighbors and replacement is false, in which case
- * all the neighbors will be selected. Otherwise, it will pick the minimum
- * number of neighbors between the fanout value and the total number of
- * neighbors.
+ * >= 0 or -1.
+ *   - When the value is -1, all neighbors will be chosen for sampling. It is
+ *     equivalent to selecting all neighbors when the fanout is >= the number of
+ *     neighbors (and replacement is set to false).
+ *   - When the value is a non-negative integer, it serves as an upper bound
+ *     on the number of neighbors selected.
  * @param replace Boolean indicating whether the sample is performed with or
- * without replacement. If True, a value can be selected multiple
- * times. Otherwise, each value can be selected only once.
+ * without replacement. If True, a value can be selected multiple times.
+ * Otherwise, each value can be selected only once.
  * @param options Tensor options specifying the desired data type of the result.
  *
  * @return A tensor containing the picked neighbors.
@@ -232,6 +240,34 @@ torch::Tensor Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options);

+/**
+ * @brief Picks a specified number of neighbors for a node per edge type,
+ * starting from the given offset and having the specified number of neighbors.
+ *
+ * @param offset The starting edge ID for the connected neighbors of the sampled
+ * node.
+ * @param num_neighbors The number of neighbors to pick.
+ * @param fanouts The edge sampling numbers corresponding to each edge type for
+ * a single node. The value of each fanout should be >= 0 or = -1.
+ *   - When the value is -1, all neighbors will be chosen for sampling. It is
+ *     equivalent to selecting all neighbors when the fanout is >= the number of
+ *     neighbors (and replacement is set to false).
+ *   - When the value is a non-negative integer, it serves as an upper bound
+ *     on the number of neighbors selected.
+ * @param replace Boolean indicating whether the sample is performed with or
+ * without replacement. If True, a value can be selected multiple times.
+ * Otherwise, each value can be selected only once.
+ * @param options Tensor options specifying the desired data type of the result.
+ * @param type_per_edge Tensor representing the type of each edge in the
+ * original graph.
+ *
+ * @return A tensor containing the picked neighbors.
+ */
+torch::Tensor PickByEtype(
+    int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
+    bool replace, const torch::TensorOptions& options,
+    const torch::Tensor& type_per_edge);
+
 }  // namespace sampling
 }  // namespace graphbolt
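For orientation, the `offset` and `num_neighbors` parameters documented above describe one node's contiguous slice of edges in the CSC layout. The following is a minimal sketch, assuming the conventional CSC rule that node `v` owns edge positions `indptr[v]` up to (but not including) `indptr[v + 1]`. The helper name `neighborhood_range` is hypothetical and not part of GraphBolt; the `indptr` values are taken from the test graph in this commit.

```python
def neighborhood_range(indptr, node):
    """Return (offset, num_neighbors) for `node` in a CSC graph.

    Assumes the conventional CSC layout: the edges of `node` occupy
    positions indptr[node] .. indptr[node + 1] - 1.
    """
    offset = indptr[node]
    num_neighbors = indptr[node + 1] - offset
    return offset, num_neighbors


# indptr of the 5-node, 12-edge test graph used in this commit.
indptr = [0, 3, 5, 7, 9, 12]
ranges = [neighborhood_range(indptr, v) for v in (1, 3, 4)]
```

Under that assumption, these `(offset, num_neighbors)` pairs are the ranges that `Pick` and `PickByEtype` would operate on for the seed nodes 1, 3 and 4.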
graphbolt/src/csc_sampling_graph.cc
@@ -122,9 +122,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
 }

 c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
-    const torch::Tensor& nodes, int64_t fanout, bool replace) const {
+    const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+    bool replace) const {
   const int64_t num_nodes = nodes.size(0);
+  // If true, perform sampling for each edge type of each node, otherwise just
+  // sample once for each node without regard to edge types.
+  bool consider_etype = (fanouts.size() > 1);
   std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
   torch::Tensor num_picked_neighbors_per_node =
       torch::zeros({num_nodes + 1}, indptr_.options());
@@ -147,8 +150,14 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
       continue;
     }
-    picked_neighbors_per_node[i] =
-        Pick(offset, num_neighbors, fanout, replace, indptr_.options());
+    if (consider_etype) {
+      picked_neighbors_per_node[i] = PickByEtype(
+          offset, num_neighbors, fanouts, replace, indptr_.options(),
+          type_per_edge_.value());
+    } else {
+      picked_neighbors_per_node[i] = Pick(
+          offset, num_neighbors, fanouts[0], replace, indptr_.options());
+    }
     num_picked_neighbors_per_node[i + 1] =
         picked_neighbors_per_node[i].size(0);
   }
@@ -214,5 +223,31 @@ torch::Tensor Pick(
   return picked_neighbors;
 }

+torch::Tensor PickByEtype(
+    int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
+    bool replace, const torch::TensorOptions& options,
+    const torch::Tensor& type_per_edge) {
+  std::vector<torch::Tensor> picked_neighbors(
+      fanouts.size(), torch::tensor({}, options));
+  int64_t etype_begin = offset;
+  int64_t etype_end = offset;
+  while (etype_end < offset + num_neighbors) {
+    int64_t etype = type_per_edge[etype_end].item<int64_t>();
+    int64_t fanout = fanouts[etype];
+    while (etype_end < offset + num_neighbors &&
+           type_per_edge[etype_end].item<int64_t>() == etype) {
+      etype_end++;
+    }
+    // Do sampling for one etype.
+    if (fanout != 0) {
+      picked_neighbors[etype] =
+          Pick(etype_begin, etype_end - etype_begin, fanout, replace, options);
+    }
+    etype_begin = etype_end;
+  }
+
+  return torch::cat(picked_neighbors, 0);
+}
+
 }  // namespace sampling
 }  // namespace graphbolt
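The run-scanning loop in `PickByEtype` can be easier to follow in a language-neutral form. Below is a rough pure-Python sketch of the same idea, not the library's implementation: it walks each contiguous run of equal edge types within a node's edge range and samples each run separately. The function `pick` is a hypothetical stand-in for the C++ `Pick` (whose body is not shown in this diff), and it returns edge positions purely for illustration. Like the C++ loop, the sketch appears to rely on edges of the same type being stored contiguously within each node's range; if a type occurred in two separate runs, the later run would overwrite the earlier one.

```python
import random


def pick(offset, num_neighbors, fanout, replace, rng):
    """Hypothetical stand-in for Pick(): choose edge positions in a range."""
    edges = range(offset, offset + num_neighbors)
    if fanout == -1 or (not replace and fanout >= num_neighbors):
        return list(edges)  # take every neighbor
    if replace:
        return rng.choices(edges, k=fanout)
    return rng.sample(edges, fanout)


def pick_by_etype(offset, num_neighbors, fanouts, replace, type_per_edge, rng):
    """Sample each contiguous run of one edge type with its own fanout."""
    per_etype = [[] for _ in fanouts]
    begin = offset
    end = offset + num_neighbors
    while begin < end:
        etype = type_per_edge[begin]
        run_end = begin
        # Advance to the end of the run of edges sharing this type.
        while run_end < end and type_per_edge[run_end] == etype:
            run_end += 1
        if fanouts[etype] != 0:
            per_etype[etype] = pick(
                begin, run_end - begin, fanouts[etype], replace, rng
            )
        begin = run_end
    # Concatenate in edge-type order, as torch::cat does over the vector.
    return [edge for group in per_etype for edge in group]


# Edge types of the test graph in this commit; node 4 owns edges 9..11.
type_per_edge = [0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1]
all_edges = pick_by_etype(9, 3, [-1, -1], False, type_per_edge, random.Random(0))
```

Here `all_edges` contains every edge of node 4 because both fanouts are -1; a fanout of 0 for a type skips that run entirely.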
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
@@ -197,7 +197,7 @@ class CSCSamplingGraph:
     def sample_neighbors(
         self,
         nodes: torch.Tensor,
-        fanout: int,
+        fanouts: torch.Tensor,
         replace: bool = False,
     ) -> torch.ScriptObject:
         """Sample neighboring edges of the given nodes and return the induced
@@ -207,22 +207,42 @@ class CSCSamplingGraph:
         ----------
         nodes: torch.Tensor
             IDs of the given seed nodes.
-        fanout: int
-            The number of edges to be sampled for each node. It should be
-            >= 0 or -1. If -1 is given, it is equivalent to when the fanout
-            is greater or equal to the number of neighbors and replacement
-            is false, in which case all the neighbors will be selected.
-            Otherwise, it will pick the minimum number of neighbors between
-            the fanout value and the total number of neighbors.
+        fanouts: torch.Tensor
+            The number of edges to be sampled for each node with or without
+            considering edge types.
+              - When the length is 1, it indicates that the fanout applies to
+                all neighbors of the node as a collective, regardless of the
+                edge type.
+              - Otherwise, the length should be equal to the number of edge
+                types, and each fanout value corresponds to a specific edge
+                type of the nodes.
+            The value of each fanout should be >= 0 or = -1.
+              - When the value is -1, all neighbors will be chosen for
+                sampling. It is equivalent to selecting all neighbors when
+                the fanout is >= the number of neighbors (and replacement
+                is set to false).
+              - When the value is a non-negative integer, it serves as an
+                upper bound on the number of neighbors selected.
         replace: bool
             Boolean indicating whether the sample is performed with or
             without replacement. If True, a value can be selected multiple
             times. Otherwise, each value can be selected only once.
         """
         # Ensure nodes is 1-D tensor.
         assert nodes.dim() == 1, "Nodes should be 1-D tensor."
-        assert fanout >= 0 or fanout == -1, "Fanout should have value >= 0 or -1"
-        return self._c_csc_graph.sample_neighbors(nodes, fanout, replace)
+        assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
+        if fanouts.size(0) > 1:
+            assert (
+                self.type_per_edge is not None
+            ), "To perform sampling for each edge type (when the length of \
+                `fanouts` > 1), the graph must include edge type information."
+        assert torch.all(
+            (fanouts >= 0) | (fanouts == -1)
+        ), "Fanouts should consist of values that are either -1 or \
+            greater than or equal to 0."
+        return self._c_csc_graph.sample_neighbors(
+            nodes, fanouts.tolist(), replace
+        )

     def copy_to_shared_memory(self, shared_memory_name: str):
         """Copy the graph to shared memory.
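The checks in the Python wrapper can be summarized without PyTorch. This is a minimal sketch on plain lists, assuming the same rules as the assertions above; `validate_fanouts` is a hypothetical helper, not part of the `dgl.graphbolt` API. It also rejects an empty list, which is an extra guard added here for illustration: no such check is visible in the hunks shown, and the C++ side reads `fanouts[0]` whenever the length is not greater than 1.

```python
def validate_fanouts(fanouts, has_edge_types):
    """Check a plain list of fanouts the way the wrapper's asserts do.

    Hypothetical helper for illustration only.
    """
    if len(fanouts) == 0:
        # Extra guard, not taken from the commit.
        raise ValueError("fanouts must contain at least one value")
    if len(fanouts) > 1 and not has_edge_types:
        raise ValueError(
            "per-edge-type fanouts require edge type information"
        )
    if any(f < -1 for f in fanouts):
        raise ValueError("each fanout must be -1 or >= 0")
    return list(fanouts)


single = validate_fanouts([4], has_edge_types=False)
per_type = validate_fanouts([-1, 2], has_edge_types=True)
```

A single-element list is accepted whether or not the graph carries edge types, while a longer list is only accepted when it does.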
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
@@ -402,21 +402,23 @@ def test_sample_neighbors():
     num_edges = 12
     indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
     indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
+    type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
     assert indptr[-1] == num_edges
     assert indptr[-1] == len(indices)

     # Construct CSCSamplingGraph.
-    graph = gb.from_csc(indptr, indices)
+    graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)

     # Generate subgraph via sample neighbors.
     nodes = torch.LongTensor([1, 3, 4])
-    fanout = -1
-    subgraph = graph.sample_neighbors(nodes, fanout)
+    fanouts = torch.tensor([2, 2, 3])
+    subgraph = graph.sample_neighbors(nodes, fanouts)

     # Verify in subgraph.
     assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
     assert torch.equal(
-        subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
+        torch.sort(subgraph.indices)[0],
+        torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0],
     )
     assert torch.equal(subgraph.reverse_column_node_ids, nodes)
     assert subgraph.reverse_row_node_ids is None
@@ -429,10 +431,21 @@ def test_sample_neighbors():
     reason="Graph is CPU only at present.",
 )
 @pytest.mark.parametrize(
-    "fanout, expected_sampled_num",
-    [(0, 0), (1, 3), (2, 6), (3, 7), (4, 7), (-1, 7)],
+    "fanouts, expected_sampled_num",
+    [
+        ([0], 0),
+        ([1], 3),
+        ([2], 6),
+        ([4], 7),
+        ([-1], 7),
+        ([0, 0], 0),
+        ([1, 0], 3),
+        ([1, 1], 6),
+        ([2, 2], 7),
+        ([-1, -1], 7),
+    ],
 )
-def test_sample_neighbors_fanout(fanout, expected_sampled_num):
+def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
     """Original graph in COO:
     1   0   1   0   1
     1   0   1   1   0
@@ -445,15 +458,17 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
     num_edges = 12
     indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
     indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
+    type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
     assert indptr[-1] == num_edges
     assert indptr[-1] == len(indices)

     # Construct CSCSamplingGraph.
-    graph = gb.from_csc(indptr, indices)
+    graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)

     # Generate subgraph via sample neighbors.
     nodes = torch.LongTensor([1, 3, 4])
-    subgraph = graph.sample_neighbors(nodes, fanout)
+    fanouts = torch.LongTensor(fanouts)
+    subgraph = graph.sample_neighbors(nodes, fanouts)

     # Verify in subgraph.
     sampled_num = subgraph.indices.size(0)
@@ -488,7 +503,9 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):

     # Generate subgraph via sample neighbors.
     nodes = torch.LongTensor([1, 3, 4])
-    subgraph = graph.sample_neighbors(nodes, fanout=4, replace=replace)
+    subgraph = graph.sample_neighbors(
+        nodes, fanouts=torch.LongTensor([4]), replace=replace
+    )

     # Verify in subgraph.
     sampled_num = subgraph.indices.size(0)
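The `expected_sampled_num` values in the parametrized test can be reproduced by hand. The sketch below does so on plain lists, assuming sampling without replacement (the default `replace=False`), where each group contributes `min(fanout, group size)` edges and -1 means the whole group. The function name `expected_sampled_num` mirrors the test parameter, but the function itself is hypothetical and not part of the test file.

```python
def expected_sampled_num(indptr, type_per_edge, nodes, fanouts):
    """Count the edges sampled without replacement for the given fanouts."""
    total = 0
    for v in nodes:
        etypes = type_per_edge[indptr[v]:indptr[v + 1]]
        if len(fanouts) == 1:
            # One fanout: all neighbors form a single group.
            group_sizes = [len(etypes)]
        else:
            # One group per edge type.
            group_sizes = [etypes.count(t) for t in range(len(fanouts))]
        for size, fanout in zip(group_sizes, fanouts):
            total += size if fanout == -1 else min(fanout, size)
    return total


# Graph and seed nodes from the tests in this commit.
indptr = [0, 3, 5, 7, 9, 12]
type_per_edge = [0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1]
nodes = [1, 3, 4]
count_per_type = expected_sampled_num(indptr, type_per_edge, nodes, [1, 1])
```

For the seed nodes 1, 3 and 4, each node has at least one edge of each type, so `[1, 1]` yields two edges per node.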