Unverified Commit 2c03fe99 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Dispatch edge ids in neighbor sampling (#5889)

parent 2489f579
...@@ -141,21 +141,25 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -141,21 +141,25 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor num_picked_neighbors_per_node = torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options()); torch::zeros({num_nodes + 1}, indptr_.options());
torch::parallel_for(0, num_nodes, 32, [&](size_t b, size_t e) { AT_DISPATCH_INTEGRAL_TYPES(
for (size_t i = b; i < e; ++i) { indptr_.scalar_type(), "parallel_for", ([&] {
torch::parallel_for(0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
for (scalar_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>(); const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK( TORCH_CHECK(
nid >= 0 && nid < NumNodes(), nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the graph's " "The seed nodes' IDs should fall within the range of the "
"node IDs."); "graph's node IDs.");
const auto offset = indptr_[nid].item<int64_t>(); const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_[nid + 1].item<int64_t>() - offset; const auto num_neighbors = indptr_data[nid + 1] - offset;
if (num_neighbors == 0) { if (num_neighbors == 0) {
// Initialization is performed here because all tensors will be // Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined tensor // concatenated in the master thread, and having an undefined
// during concatenation can result in a crash. // tensor during concatenation can result in a crash.
picked_neighbors_per_node[i] = torch::tensor({}, indptr_.options()); picked_neighbors_per_node[i] =
torch::tensor({}, indptr_.options());
continue; continue;
} }
...@@ -172,6 +176,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -172,6 +176,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
picked_neighbors_per_node[i].size(0); picked_neighbors_per_node[i].size(0);
} }
}); // End of the thread. }); // End of the thread.
}));
torch::Tensor subgraph_indptr = torch::Tensor subgraph_indptr =
torch::cumsum(num_picked_neighbors_per_node, 0); torch::cumsum(num_picked_neighbors_per_node, 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment