Unverified Commit 704aa423 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add replace for neighbor sampling (#5770)

[Graphbolt] Add replace for sampling
parent c9c165f7
......@@ -121,15 +121,20 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*
* @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
* times.Otherwise, each value can be selected only once.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, int64_t fanout) const;
const torch::Tensor& nodes, int64_t fanout, bool replace) const;
/**
* @brief Copy the graph to shared memory.
......@@ -211,14 +216,20 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* >= 0 or -1. If -1 is given, it is equivalent to when the fanout is greater
* or equal to the number of neighbors and replacement is false, in which case
* all the neighbors will be selected. Otherwise, it will pick the minimum
* number of neighbors between the fanout value and the total number of
* neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple
* times.Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
*
* @return A tensor containing the picked neighbors.
*/
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options);
} // namespace sampling
......
......@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, int64_t fanout) const {
const torch::Tensor& nodes, int64_t fanout, bool replace) const {
const int64_t num_nodes = nodes.size(0);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
......@@ -148,7 +148,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
}
picked_neighbors_per_node[i] =
Pick(offset, num_neighbors, fanout, indptr_.options());
Pick(offset, num_neighbors, fanout, replace, indptr_.options());
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
......@@ -197,15 +197,20 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
}
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options) {
torch::Tensor picked_neighbors;
if ((fanout == -1) || (num_neighbors <= fanout)) {
if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
} else {
picked_neighbors = torch::randperm(num_neighbors) + offset;
if (replace) {
picked_neighbors =
torch::randint(offset, offset + num_neighbors, {fanout}, options);
} else {
picked_neighbors = torch::randperm(num_neighbors, options) + offset;
picked_neighbors = picked_neighbors.slice(0, 0, fanout);
}
}
return picked_neighbors;
}
......
......@@ -195,7 +195,10 @@ class CSCSamplingGraph:
return self._c_csc_graph.in_subgraph(nodes)
def sample_neighbors(
self, nodes: torch.Tensor, fanout: int
self,
nodes: torch.Tensor,
fanout: int,
replace: bool = False,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -206,14 +209,20 @@ class CSCSamplingGraph:
IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, all neighbors will be selected.
>= 0 or -1. If -1 is given, it is equivalent to when the fanout
is greater or equal to the number of neighbors and replacement
is false, in which case all the neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
replace: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
"""
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanout >= 0 or fanout == -1, "Fanout shoud have value >= 0 or -1"
return self._c_csc_graph.sample_neighbors(nodes, fanout)
return self._c_csc_graph.sample_neighbors(nodes, fanout, replace)
def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory.
......
......@@ -460,6 +460,41 @@ def test_sample_neighbors_fanout(fanout, expected_sampled_num):
assert sampled_num == expected_sampled_num
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"replace, expected_sampled_num", [(False, 7), (True, 12)]
)
def test_sample_neighbors_replace(replace, expected_sampled_num):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
num_nodes = 5
num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanout=4, replace=replace)
# Verify in subgraph.
sampled_num = subgraph.indices.size(0)
assert sampled_num == expected_sampled_num
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"""Check if two tensors are on the same shared memory.
......@@ -574,3 +609,7 @@ def test_hetero_graph_on_shared_memory(
assert metadata.edge_type_to_id == graph1.metadata.edge_type_to_id
assert metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
if __name__ == "__main__":
test_sample_neighbors_replace(True, 12)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment