Unverified Commit b3234516 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add sample etype neighbors (#5771)

parent 704aa423
......@@ -120,21 +120,28 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* subgraph.
*
* @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all
* neighbors of the node as a collective, regardless of the edge type.
* - Otherwise, the length should equal to the number of edge types, and
* each fanout value corresponds to a specific edge type of the node.
* The value of each fanout should be >= 0 or = -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
* times.Otherwise, each value can be selected only once.
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, int64_t fanout, bool replace) const;
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const;
/**
* @brief Copy the graph to shared memory.
......@@ -216,14 +223,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* >= 0 or -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
* times.Otherwise, each value can be selected only once.
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
*
* @return A tensor containing the picked neighbors.
......@@ -232,6 +240,34 @@ torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options);
/**
* @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanouts The edge sampling numbers corresponding to each edge type for
* a single node. The value of each fanout should be >= 0 or = 1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum threshold
* for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
* @param type_per_edge Tensor representing the type of each edge in the
* original graph.
*
* @return A tensor containing the picked neighbors.
*/
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge);
} // namespace sampling
} // namespace graphbolt
......
......@@ -122,9 +122,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, int64_t fanout, bool replace) const {
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const {
const int64_t num_nodes = nodes.size(0);
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool consider_etype = (fanouts.size() > 1);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options());
......@@ -147,8 +150,14 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
continue;
}
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value());
} else {
picked_neighbors_per_node[i] =
Pick(offset, num_neighbors, fanout, replace, indptr_.options());
Pick(offset, num_neighbors, fanouts[0], replace, indptr_.options());
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
......@@ -214,5 +223,31 @@ torch::Tensor Pick(
return picked_neighbors;
}
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge) {
std::vector<torch::Tensor> picked_neighbors(
fanouts.size(), torch::tensor({}, options));
int64_t etype_begin = offset;
int64_t etype_end = offset;
while (etype_end < offset + num_neighbors) {
int64_t etype = type_per_edge[etype_end].item<int64_t>();
int64_t fanout = fanouts[etype];
while (etype_end < offset + num_neighbors &&
type_per_edge[etype_end].item<int64_t>() == etype) {
etype_end++;
}
// Do sampling for one etype.
if (fanout != 0) {
picked_neighbors[etype] =
Pick(etype_begin, etype_end - etype_begin, fanout, replace, options);
}
etype_begin = etype_end;
}
return torch::cat(picked_neighbors, 0);
}
} // namespace sampling
} // namespace graphbolt
......@@ -197,7 +197,7 @@ class CSCSamplingGraph:
def sample_neighbors(
self,
nodes: torch.Tensor,
fanout: int,
fanouts: torch.Tensor,
replace: bool = False,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
......@@ -207,22 +207,42 @@ class CSCSamplingGraph:
----------
nodes: torch.Tensor
IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, it is equivalent to when the fanout
is greater or equal to the number of neighbors and replacement
is false, in which case all the neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
replace: bool
fanouts: torch.Tensor
The number of edges to be sampled for each node with or without
considering edge types.
- When the length is 1, it indicates that the fanout applies to
all neighbors of the node as a collective, regardless of the
edge type.
- Otherwise, the length should equal to the number of edge
types, and each fanout value corresponds to a specific edge
type of the nodes.
The value of each fanout should be >= 0 or = -1.
- When the value is -1, all neighbors will be chosen for
sampling. It is equivalent to selecting all neighbors when
the fanout is >= the number of neighbors (and replacement
is set to false).
- When the value is a non-negative integer, it serves as a
minimum threshold for selecting neighbors.
replce: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
"""
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanout >= 0 or fanout == -1, "Fanout shoud have value >= 0 or -1"
return self._c_csc_graph.sample_neighbors(nodes, fanout, replace)
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
if fanouts.size(0) > 1:
assert (
self.type_per_edge is not None
), "To perform sampling for each edge type (when the length of \
`fanouts` > 1), the graph must include edge type information."
assert torch.all(
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace
)
def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory.
......
......@@ -402,21 +402,23 @@ def test_sample_neighbors():
num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
fanout = -1
subgraph = graph.sample_neighbors(nodes, fanout)
fanouts = torch.tensor([2, 2, 3])
subgraph = graph.sample_neighbors(nodes, fanouts)
# Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal(
subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
torch.sort(subgraph.indices)[0],
torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0],
)
assert torch.equal(subgraph.reverse_column_node_ids, nodes)
assert subgraph.reverse_row_node_ids is None
......@@ -429,10 +431,21 @@ def test_sample_neighbors():
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"fanout, expected_sampled_num",
[(0, 0), (1, 3), (2, 6), (3, 7), (4, 7), (-1, 7)],
"fanouts, expected_sampled_num",
[
([0], 0),
([1], 3),
([2], 6),
([4], 7),
([-1], 7),
([0, 0], 0),
([1, 0], 3),
([1, 1], 6),
([2, 2], 7),
([-1, -1], 7),
],
)
def test_sample_neighbors_fanout(fanout, expected_sampled_num):
def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -445,15 +458,17 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanout)
fanouts = torch.LongTensor(fanouts)
subgraph = graph.sample_neighbors(nodes, fanouts)
# Verify in subgraph.
sampled_num = subgraph.indices.size(0)
......@@ -488,7 +503,9 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanout=4, replace=replace)
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([4]), replace=replace
)
# Verify in subgraph.
sampled_num = subgraph.indices.size(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment