Unverified Commit a99095e7 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt]Simple neighbor sampling that picking all neighbors for given nodes (#5765)

[GraphBolt] Add neighbor sampling picking all neighbors for given nodes
parent 9ff56d20
...@@ -123,11 +123,46 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -123,11 +123,46 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes) const { const torch::Tensor& nodes) const {
// TODO(#5692): implement this. const int64_t num_nodes = nodes.size(0);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options());
torch::parallel_for(0, num_nodes, 32, [&](size_t b, size_t e) {
for (size_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the graph's "
"node IDs.");
const auto offset = indptr_[nid].item<int64_t>();
const auto num_neighbors = indptr_[nid + 1].item<int64_t>() - offset;
if (num_neighbors == 0) {
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined tensor
// during concatenation can result in a crash.
picked_neighbors_per_node[i] = torch::tensor({}, indptr_.options());
continue;
}
picked_neighbors_per_node[i] =
torch::arange(offset, offset + num_neighbors);
num_picked_neighbors_per_node[i + 1] = num_neighbors;
}
}); // End of the thread.
torch::Tensor subgraph_indptr =
torch::cumsum(num_picked_neighbors_per_node, 0);
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
return c10::make_intrusive<SampledSubgraph>( return c10::make_intrusive<SampledSubgraph>(
torch::zeros({nodes.size(0) + 1}, indptr_.options()), subgraph_indptr, subgraph_indices, nodes, torch::nullopt, torch::nullopt,
torch::zeros({1}, indptr_.options()), nodes, torch::nullopt, torch::nullopt);
torch::nullopt, torch::nullopt);
} }
c10::intrusive_ptr<CSCSamplingGraph> c10::intrusive_ptr<CSCSamplingGraph>
......
...@@ -413,8 +413,10 @@ def test_sample_neighbors(): ...@@ -413,8 +413,10 @@ def test_sample_neighbors():
subgraph = graph.sample_neighbors(nodes) subgraph = graph.sample_neighbors(nodes)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 0, 0, 0])) assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
assert torch.equal(subgraph.indices, torch.LongTensor([0])) assert torch.equal(
subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
)
assert torch.equal(subgraph.reverse_column_node_ids, nodes) assert torch.equal(subgraph.reverse_column_node_ids, nodes)
assert subgraph.reverse_row_node_ids is None assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None assert subgraph.reverse_edge_ids is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment