Unverified commit 80d16efa authored by keli-wen, committed by GitHub

[Graphbolt] Add `cat` optimization for UniformPick (#6030)

parent 327589c8
@@ -141,50 +141,63 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
   // If true, perform sampling for each edge type of each node, otherwise just
   // sample once for each node with no regard of edge types.
   bool consider_etype = (fanouts.size() > 1);
-  std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
+  const int64_t num_threads = torch::get_num_threads();
+  std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
   torch::Tensor num_picked_neighbors_per_node =
       torch::zeros({num_nodes + 1}, indptr_.options());
+  // Calculate GrainSize for parallel_for.
+  // Set the default grain size to 64.
+  const int64_t grain_size = 64;
   AT_DISPATCH_INTEGRAL_TYPES(
       indptr_.scalar_type(), "parallel_for", ([&] {
-        torch::parallel_for(0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
-          const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
-          for (scalar_t i = b; i < e; ++i) {
-            const auto nid = nodes[i].item<int64_t>();
-            TORCH_CHECK(
-                nid >= 0 && nid < NumNodes(),
-                "The seed nodes' IDs should fall within the range of the "
-                "graph's node IDs.");
-            const auto offset = indptr_data[nid];
-            const auto num_neighbors = indptr_data[nid + 1] - offset;
-            if (num_neighbors == 0) {
-              // To avoid crashing during concatenation in the master thread,
-              // initializing with empty tensors.
-              picked_neighbors_per_node[i] =
-                  torch::tensor({}, indptr_.options());
-              continue;
-            }
-            if (consider_etype) {
-              picked_neighbors_per_node[i] = PickByEtype(
-                  offset, num_neighbors, fanouts, replace, indptr_.options(),
-                  type_per_edge_.value(), probs_or_mask, args);
-            } else {
-              picked_neighbors_per_node[i] = Pick(
-                  offset, num_neighbors, fanouts[0], replace, indptr_.options(),
-                  probs_or_mask, args);
-            }
-            num_picked_neighbors_per_node[i + 1] =
-                picked_neighbors_per_node[i].size(0);
-          }
-        });  // End of the thread.
+        torch::parallel_for(
+            0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
+              const auto indptr_options = indptr_.options();
+              const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
+              // Get current thread id.
+              auto thread_id = torch::get_thread_num();
+              int64_t local_grain_size = end - begin;
+              std::vector<torch::Tensor> picked_neighbors_cur_thread(
+                  local_grain_size);
+              for (scalar_t i = begin; i < end; ++i) {
+                const auto nid = nodes[i].item<int64_t>();
+                TORCH_CHECK(
+                    nid >= 0 && nid < NumNodes(),
+                    "The seed nodes' IDs should fall within the range of the "
+                    "graph's node IDs.");
+                const auto offset = indptr_data[nid];
+                const auto num_neighbors = indptr_data[nid + 1] - offset;
+                if (num_neighbors == 0) {
+                  // To avoid crashing during concatenation in the master
+                  // thread, initializing with empty tensors.
+                  picked_neighbors_cur_thread[i - begin] =
+                      torch::tensor({}, indptr_options);
+                  continue;
+                }
+                if (consider_etype) {
+                  picked_neighbors_cur_thread[i - begin] = PickByEtype(
+                      offset, num_neighbors, fanouts, replace, indptr_options,
+                      type_per_edge_.value(), probs_or_mask, args);
+                } else {
+                  picked_neighbors_cur_thread[i - begin] = Pick(
+                      offset, num_neighbors, fanouts[0], replace,
+                      indptr_options, probs_or_mask, args);
+                }
+                num_picked_neighbors_per_node[i + 1] =
+                    picked_neighbors_cur_thread[i - begin].size(0);
+              }
+              picked_neighbors_per_thread[thread_id] =
+                  torch::cat(picked_neighbors_cur_thread);
+            });  // End of parallel_for.
       }));
   torch::Tensor subgraph_indptr =
       torch::cumsum(num_picked_neighbors_per_node, 0);
-  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
+  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
   torch::Tensor subgraph_indices =
       torch::index_select(indices_, 0, picked_eids);
   torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
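The heart of the change is the per-thread accumulation: each parallel_for chunk concatenates its picks once, so the main thread only has to `cat` num_threads tensors instead of one tensor per node. Below is a minimal, self-contained sketch of that pattern against the public libtorch API; the item count, fanout, fake pick step, and all names (num_items, picked_per_thread, and so on) are invented for illustration and are not graphbolt code. It also assumes each worker thread of at::parallel_for processes at most one contiguous chunk (as with PyTorch's OpenMP backend), which is what makes the single write to picked_per_thread[thread_id] safe.

// Minimal sketch of the per-thread `cat` pattern (toy stand-ins, not graphbolt code).
#include <ATen/Parallel.h>
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  const int64_t num_items = 1000;  // Stand-in for num_nodes.
  const int64_t fanout = 4;        // Stand-in for the sampling fanout.

  const int64_t num_threads = at::get_num_threads();
  // One output slot per thread instead of one per item. Pre-filling with
  // empty tensors keeps the final torch::cat valid even if a thread never
  // receives a chunk of work.
  std::vector<torch::Tensor> picked_per_thread(
      num_threads, torch::empty({0}, torch::kInt64));
  torch::Tensor num_picked_per_item =
      torch::zeros({num_items + 1}, torch::kInt64);

  const int64_t grain_size = 64;
  at::parallel_for(0, num_items, grain_size, [&](int64_t begin, int64_t end) {
    const auto thread_id = at::get_thread_num();
    std::vector<torch::Tensor> picked_cur_thread(end - begin);
    for (int64_t i = begin; i < end; ++i) {
      if (i % 7 == 0) {
        // Items with nothing to pick contribute an empty tensor so the
        // per-thread concatenation never sees an undefined Tensor; the
        // corresponding count slot stays zero.
        picked_cur_thread[i - begin] = torch::empty({0}, torch::kInt64);
        continue;
      }
      // Stand-in for Pick()/PickByEtype(): draw `fanout` fake edge IDs.
      torch::Tensor picks = torch::randint(0, 100, {fanout}, torch::kInt64);
      picked_cur_thread[i - begin] = picks;
      num_picked_per_item[i + 1] = picks.size(0);
    }
    // Single concatenation per chunk; assumes each worker thread handles at
    // most one contiguous chunk, so this assignment is not overwritten.
    picked_per_thread[thread_id] = torch::cat(picked_cur_thread);
  });

  // Merge num_threads tensors (instead of num_items tensors) and build a
  // CSC-style indptr from the per-item counts.
  torch::Tensor picked_ids = torch::cat(picked_per_thread);
  torch::Tensor indptr = torch::cumsum(num_picked_per_item, 0);

  std::cout << "total picked: " << picked_ids.size(0)
            << ", indptr tail: " << indptr[-1].item<int64_t>() << std::endl;
  return 0;
}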