Unverified Commit c9c165f7 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add fanout for neighbor sampling (#5768)

[Graphbolt] Add fanout for sampling
parent a99095e7
......@@ -120,12 +120,16 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* subgraph.
*
* @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes) const;
const torch::Tensor& nodes, int64_t fanout) const;
/**
* @brief Copy the graph to shared memory.
......@@ -199,6 +203,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_;
};
/**
* @brief Picks neighbors for a single node out of its contiguous range of edge
* IDs [offset, offset + num_neighbors).
*
* If fanout is -1, or the node has no more than fanout neighbors, every edge ID
* in the range is returned in ascending order. Otherwise fanout edge IDs are
* taken from a random permutation of the range.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The total number of neighbors of the node, i.e. the
* length of the edge ID range that can be picked from.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* @param options Tensor options specifying the desired data type of the result.
*
* @return A 1-D tensor containing the edge IDs of the picked neighbors.
*/
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options);
} // namespace sampling
} // namespace graphbolt
......
......@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes) const {
const torch::Tensor& nodes, int64_t fanout) const {
const int64_t num_nodes = nodes.size(0);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
......@@ -148,8 +148,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
}
picked_neighbors_per_node[i] =
torch::arange(offset, offset + num_neighbors);
num_picked_neighbors_per_node[i + 1] = num_neighbors;
Pick(offset, num_neighbors, fanout, indptr_.options());
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
}); // End of the thread.
......@@ -195,5 +196,18 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors));
}
/**
 * @brief Picks neighbors for a single node out of its contiguous range of edge
 * IDs [offset, offset + num_neighbors).
 *
 * @param offset The first edge ID of the node's neighbor range.
 * @param num_neighbors The total number of neighbors of the node.
 * @param fanout The number of edges to sample. It should be >= 0 or -1. With
 * -1, or when the node has no more than fanout neighbors, every neighbor is
 * returned; otherwise exactly fanout of them are picked at random.
 * @param options Tensor options applied to the result in both branches.
 *
 * @return A 1-D tensor holding the edge IDs of the picked neighbors.
 */
torch::Tensor Pick(
    int64_t offset, int64_t num_neighbors, int64_t fanout,
    const torch::TensorOptions& options) {
  // Nothing to sample: take the whole neighbor range in ascending order.
  if ((fanout == -1) || (num_neighbors <= fanout)) {
    return torch::arange(offset, offset + num_neighbors, options);
  }
  // Build the permutation with the caller's options so the sampled result has
  // the same dtype as the take-all branch above, and shift only the `fanout`
  // entries that are kept instead of the whole permutation.
  return torch::randperm(num_neighbors, options).slice(0, 0, fanout) + offset;
}
} // namespace sampling
} // namespace graphbolt
......@@ -194,7 +194,9 @@ class CSCSamplingGraph:
), "Nodes cannot have duplicate values."
return self._c_csc_graph.in_subgraph(nodes)
def sample_neighbors(self, nodes: torch.Tensor) -> torch.ScriptObject:
def sample_neighbors(
self, nodes: torch.Tensor, fanout: int
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -202,10 +204,16 @@ class CSCSamplingGraph:
----------
nodes: torch.Tensor
IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, all neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
"""
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
return self._c_csc_graph.sample_neighbors(nodes)
assert fanout >= 0 or fanout == -1, "Fanout shoud have value >= 0 or -1"
return self._c_csc_graph.sample_neighbors(nodes, fanout)
def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory.
......
......@@ -410,7 +410,8 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes)
fanout = -1
subgraph = graph.sample_neighbors(nodes, fanout)
# Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
......@@ -423,6 +424,42 @@ def test_sample_neighbors():
assert subgraph.type_per_edge is None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
    "fanout, expected_sampled_num",
    [(0, 0), (1, 3), (2, 6), (3, 7), (4, 7), (-1, 7)],
)
def test_sample_neighbors_fanout(fanout, expected_sampled_num):
    """Check how many edges `sample_neighbors` returns for a given fanout.

    The CSC arrays below encode this dense matrix, where entry (r, c) is 1
    when r is listed in `indices` for column c:
        1 0 1 0 1
        1 0 1 1 0
        0 1 0 1 0
        0 1 0 0 1
        1 0 0 0 1

    Columns 1, 3 and 4 (the seed nodes) hold 2, 2 and 3 entries, so the
    sampled edge count is capped at 7.
    """
    # Graph description in CSC form.
    num_nodes = 5
    num_edges = 12
    col_ptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    row_ids = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    # Sanity-check the fixture before handing it to the graph builder.
    assert col_ptr[-1] == num_edges
    assert col_ptr[-1] == len(row_ids)

    csc_graph = gb.from_csc(col_ptr, row_ids)

    # Sample around the seed nodes with the parametrized fanout.
    seeds = torch.LongTensor([1, 3, 4])
    sampled = csc_graph.sample_neighbors(seeds, fanout)

    # Only the total number of sampled edges is checked.
    assert sampled.indices.size(0) == expected_sampled_num
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"""Check if two tensors are on the same shared memory.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment