Unverified Commit b3234516 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add sample etype neighbors (#5771)

parent 704aa423
...@@ -120,21 +120,28 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -120,21 +120,28 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* subgraph. * subgraph.
* *
* @param nodes The nodes from which to sample neighbors. * @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be * @param fanouts The number of edges to be sampled for each node with or
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater * without considering edge types.
* or equal to the number of neighbors and replacement is false, in which case * - When the length is 1, it indicates that the fanout applies to all
* all the neighbors will be selected. Otherwise, it will pick the minimum * neighbors of the node as a collective, regardless of the edge type.
* number of neighbors between the fanout value and the total number of * - Otherwise, the length should equal to the number of edge types, and
* neighbors. * each fanout value corresponds to a specific edge type of the node.
* The value of each fanout should be >= 0 or = -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or * @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple * without replacement. If True, a value can be selected multiple times.
* times.Otherwise, each value can be selected only once. * Otherwise, each value can be selected only once.
* *
* @return An intrusive pointer to a SampledSubgraph object containing the * @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information. * sampled graph's information.
*/ */
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, int64_t fanout, bool replace) const; const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const;
/** /**
* @brief Copy the graph to shared memory. * @brief Copy the graph to shared memory.
...@@ -216,14 +223,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -216,14 +223,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* node. * node.
* @param num_neighbors The number of neighbors to pick. * @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be * @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater * >= 0 or -1.
* or equal to the number of neighbors and replacement is false, in which case * - When the value is -1, all neighbors will be chosen for sampling. It is
* all the neighbors will be selected. Otherwise, it will pick the minimum * equivalent to selecting all neighbors when the fanout is >= the number of
* number of neighbors between the fanout value and the total number of * neighbors (and replacement is set to false).
* neighbors. * - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or * @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple * without replacement. If True, a value can be selected multiple times.
* times.Otherwise, each value can be selected only once. * Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result. * @param options Tensor options specifying the desired data type of the result.
* *
* @return A tensor containing the picked neighbors. * @return A tensor containing the picked neighbors.
...@@ -232,6 +240,34 @@ torch::Tensor Pick( ...@@ -232,6 +240,34 @@ torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options); const torch::TensorOptions& options);
/**
* @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanouts The edge sampling numbers corresponding to each edge type for
* a single node. The value of each fanout should be >= 0 or = 1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum threshold
* for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
* @param type_per_edge Tensor representing the type of each edge in the
* original graph.
*
* @return A tensor containing the picked neighbors.
*/
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge);
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
......
...@@ -122,9 +122,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -122,9 +122,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
} }
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, int64_t fanout, bool replace) const { const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool consider_etype = (fanouts.size() > 1);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes); std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
torch::Tensor num_picked_neighbors_per_node = torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options()); torch::zeros({num_nodes + 1}, indptr_.options());
...@@ -147,8 +150,14 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -147,8 +150,14 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
continue; continue;
} }
picked_neighbors_per_node[i] = if (consider_etype) {
Pick(offset, num_neighbors, fanout, replace, indptr_.options()); picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value());
} else {
picked_neighbors_per_node[i] =
Pick(offset, num_neighbors, fanouts[0], replace, indptr_.options());
}
num_picked_neighbors_per_node[i + 1] = num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0); picked_neighbors_per_node[i].size(0);
} }
...@@ -214,5 +223,31 @@ torch::Tensor Pick( ...@@ -214,5 +223,31 @@ torch::Tensor Pick(
return picked_neighbors; return picked_neighbors;
} }
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge) {
std::vector<torch::Tensor> picked_neighbors(
fanouts.size(), torch::tensor({}, options));
int64_t etype_begin = offset;
int64_t etype_end = offset;
while (etype_end < offset + num_neighbors) {
int64_t etype = type_per_edge[etype_end].item<int64_t>();
int64_t fanout = fanouts[etype];
while (etype_end < offset + num_neighbors &&
type_per_edge[etype_end].item<int64_t>() == etype) {
etype_end++;
}
// Do sampling for one etype.
if (fanout != 0) {
picked_neighbors[etype] =
Pick(etype_begin, etype_end - etype_begin, fanout, replace, options);
}
etype_begin = etype_end;
}
return torch::cat(picked_neighbors, 0);
}
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
...@@ -197,7 +197,7 @@ class CSCSamplingGraph: ...@@ -197,7 +197,7 @@ class CSCSamplingGraph:
def sample_neighbors( def sample_neighbors(
self, self,
nodes: torch.Tensor, nodes: torch.Tensor,
fanout: int, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
) -> torch.ScriptObject: ) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
...@@ -207,22 +207,42 @@ class CSCSamplingGraph: ...@@ -207,22 +207,42 @@ class CSCSamplingGraph:
---------- ----------
nodes: torch.Tensor nodes: torch.Tensor
IDs of the given seed nodes. IDs of the given seed nodes.
fanout: int fanouts: torch.Tensor
The number of edges to be sampled for each node. It should be The number of edges to be sampled for each node with or without
>= 0 or -1. If -1 is given, it is equivalent to when the fanout considering edge types.
is greater or equal to the number of neighbors and replacement - When the length is 1, it indicates that the fanout applies to
is false, in which case all the neighbors will be selected. all neighbors of the node as a collective, regardless of the
Otherwise, it will pick the minimum number of neighbors between edge type.
the fanout value and the total number of neighbors. - Otherwise, the length should equal to the number of edge
replace: bool types, and each fanout value corresponds to a specific edge
type of the nodes.
The value of each fanout should be >= 0 or = -1.
- When the value is -1, all neighbors will be chosen for
sampling. It is equivalent to selecting all neighbors when
the fanout is >= the number of neighbors (and replacement
is set to false).
- When the value is a non-negative integer, it serves as a
minimum threshold for selecting neighbors.
replce: bool
Boolean indicating whether the sample is preformed with or Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once. times. Otherwise, each value can be selected only once.
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanout >= 0 or fanout == -1, "Fanout shoud have value >= 0 or -1" assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
return self._c_csc_graph.sample_neighbors(nodes, fanout, replace) if fanouts.size(0) > 1:
assert (
self.type_per_edge is not None
), "To perform sampling for each edge type (when the length of \
`fanouts` > 1), the graph must include edge type information."
assert torch.all(
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace
)
def copy_to_shared_memory(self, shared_memory_name: str): def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory. """Copy the graph to shared memory.
......
...@@ -402,21 +402,23 @@ def test_sample_neighbors(): ...@@ -402,21 +402,23 @@ def test_sample_neighbors():
num_edges = 12 num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges assert indptr[-1] == num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices) graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
fanout = -1 fanouts = torch.tensor([2, 2, 3])
subgraph = graph.sample_neighbors(nodes, fanout) subgraph = graph.sample_neighbors(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7])) assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal( assert torch.equal(
subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4]) torch.sort(subgraph.indices)[0],
torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0],
) )
assert torch.equal(subgraph.reverse_column_node_ids, nodes) assert torch.equal(subgraph.reverse_column_node_ids, nodes)
assert subgraph.reverse_row_node_ids is None assert subgraph.reverse_row_node_ids is None
...@@ -429,10 +431,21 @@ def test_sample_neighbors(): ...@@ -429,10 +431,21 @@ def test_sample_neighbors():
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fanout, expected_sampled_num", "fanouts, expected_sampled_num",
[(0, 0), (1, 3), (2, 6), (3, 7), (4, 7), (-1, 7)], [
([0], 0),
([1], 3),
([2], 6),
([4], 7),
([-1], 7),
([0, 0], 0),
([1, 0], 3),
([1, 1], 6),
([2, 2], 7),
([-1, -1], 7),
],
) )
def test_sample_neighbors_fanout(fanout, expected_sampled_num): def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 1 0 1 0 1
1 0 1 1 0 1 0 1 1 0
...@@ -445,15 +458,17 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num): ...@@ -445,15 +458,17 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
num_edges = 12 num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges assert indptr[-1] == num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices) graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanout) fanouts = torch.LongTensor(fanouts)
subgraph = graph.sample_neighbors(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.indices.size(0) sampled_num = subgraph.indices.size(0)
...@@ -488,7 +503,9 @@ def test_sample_neighbors_replace(replace, expected_sampled_num): ...@@ -488,7 +503,9 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanout=4, replace=replace) subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([4]), replace=replace
)
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.indices.size(0) sampled_num = subgraph.indices.size(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment