Unverified commit 80d16efa authored by keli-wen, committed by GitHub

[Graphbolt] Add `cat` optimization for UniformPick (#6030)

parent 327589c8
@@ -141,15 +141,27 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
   // If true, perform sampling for each edge type of each node, otherwise just
   // sample once for each node with no regard of edge types.
   bool consider_etype = (fanouts.size() > 1);
-  std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
+  const int64_t num_threads = torch::get_num_threads();
+  std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
   torch::Tensor num_picked_neighbors_per_node =
       torch::zeros({num_nodes + 1}, indptr_.options());
+  // Calculate GrainSize for parallel_for.
+  // Set the default grain size to 64.
+  const int64_t grain_size = 64;
   AT_DISPATCH_INTEGRAL_TYPES(
       indptr_.scalar_type(), "parallel_for", ([&] {
-        torch::parallel_for(0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
-          const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
-          for (scalar_t i = b; i < e; ++i) {
-            const auto nid = nodes[i].item<int64_t>();
-            TORCH_CHECK(
-                nid >= 0 && nid < NumNodes(),
+        torch::parallel_for(
+            0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
+              const auto indptr_options = indptr_.options();
+              const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
+              // Get current thread id.
+              auto thread_id = torch::get_thread_num();
+              int64_t local_grain_size = end - begin;
+              std::vector<torch::Tensor> picked_neighbors_cur_thread(
+                  local_grain_size);
+              for (scalar_t i = begin; i < end; ++i) {
+                const auto nid = nodes[i].item<int64_t>();
+                TORCH_CHECK(
+                    nid >= 0 && nid < NumNodes(),
@@ -159,32 +171,33 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
-            const auto num_neighbors = indptr_data[nid + 1] - offset;
-            if (num_neighbors == 0) {
-              // To avoid crashing during concatenation in the master thread,
-              // initializing with empty tensors.
-              picked_neighbors_per_node[i] =
-                  torch::tensor({}, indptr_.options());
-              continue;
-            }
-            if (consider_etype) {
-              picked_neighbors_per_node[i] = PickByEtype(
-                  offset, num_neighbors, fanouts, replace, indptr_.options(),
-                  type_per_edge_.value(), probs_or_mask, args);
-            } else {
-              picked_neighbors_per_node[i] = Pick(
-                  offset, num_neighbors, fanouts[0], replace, indptr_.options(),
-                  probs_or_mask, args);
-            }
-            num_picked_neighbors_per_node[i + 1] =
-                picked_neighbors_per_node[i].size(0);
-          }
-        });
+                const auto num_neighbors = indptr_data[nid + 1] - offset;
+                if (num_neighbors == 0) {
+                  // To avoid crashing during concatenation in the master
+                  // thread, initializing with empty tensors.
+                  picked_neighbors_cur_thread[i - begin] =
+                      torch::tensor({}, indptr_options);
+                  continue;
+                }
+                if (consider_etype) {
+                  picked_neighbors_cur_thread[i - begin] = PickByEtype(
+                      offset, num_neighbors, fanouts, replace, indptr_options,
+                      type_per_edge_.value(), probs_or_mask, args);
+                } else {
+                  picked_neighbors_cur_thread[i - begin] = Pick(
+                      offset, num_neighbors, fanouts[0], replace,
+                      indptr_options, probs_or_mask, args);
+                }
+                num_picked_neighbors_per_node[i + 1] =
+                    picked_neighbors_cur_thread[i - begin].size(0);
+              }
+              // End of the thread.
+              picked_neighbors_per_thread[thread_id] =
+                  torch::cat(picked_neighbors_cur_thread);
+            });  // End of parallel_for.
       }));
   torch::Tensor subgraph_indptr =
       torch::cumsum(num_picked_neighbors_per_node, 0);
-  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
+  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
   torch::Tensor subgraph_indices =
       torch::index_select(indices_, 0, picked_eids);
   torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
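
Below is a minimal, self-contained sketch of the idea behind this commit, assuming only libtorch/ATen: do the per-node concatenation inside the parallel region so that the main thread concatenates a handful of partial tensors instead of one tensor per seed node. `PickUniform`, the toy `indptr`, and the per-chunk bookkeeping are illustrative, not GraphBolt code; the patch itself keys the partial results by thread id via `torch::get_thread_num()`, while this sketch keys them by chunk index so the output order does not depend on how ATen schedules chunks onto threads.

```cpp
#include <ATen/Parallel.h>
#include <torch/torch.h>

#include <iostream>
#include <vector>

// Uniformly pick up to `fanout` neighbors of one node in a CSC graph and
// return the picked edge ids (illustrative stand-in for GraphBolt's Pick).
torch::Tensor PickUniform(int64_t offset, int64_t num_neighbors, int64_t fanout) {
  if (num_neighbors <= fanout) {
    return torch::arange(offset, offset + num_neighbors);
  }
  return torch::randperm(num_neighbors).slice(0, 0, fanout) + offset;
}

int main() {
  // Toy CSC indptr: node i owns edge ids [indptr[i], indptr[i + 1]).
  torch::Tensor indptr = torch::tensor({0, 3, 3, 8, 10}, torch::kInt64);
  const int64_t num_nodes = indptr.size(0) - 1;
  const int64_t fanout = 2;
  const int64_t grain_size = 2;  // The patch uses 64; tiny here for the toy graph.

  // Partial results, one slot per grain-sized block of start indices. ATen may
  // hand a worker a range larger than `grain_size`; slots that no chunk starts
  // at keep their empty initializer, so the final cat still succeeds.
  const int64_t num_slots = (num_nodes + grain_size - 1) / grain_size;
  std::vector<torch::Tensor> picked_per_chunk(
      num_slots, torch::empty({0}, torch::kInt64));
  torch::Tensor num_picked_per_node = torch::zeros({num_nodes + 1}, torch::kInt64);

  const int64_t* indptr_data = indptr.data_ptr<int64_t>();
  at::parallel_for(0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
    std::vector<torch::Tensor> picked_local(end - begin);
    for (int64_t i = begin; i < end; ++i) {
      const int64_t offset = indptr_data[i];
      const int64_t num_neighbors = indptr_data[i + 1] - offset;
      picked_local[i - begin] = num_neighbors == 0
                                    ? torch::empty({0}, torch::kInt64)
                                    : PickUniform(offset, num_neighbors, fanout);
      num_picked_per_node[i + 1] = picked_local[i - begin].size(0);
    }
    // The key step: concatenate this chunk's picks inside the parallel region.
    // Chunks cover [0, num_nodes) in non-overlapping ranges with increasing
    // start indices, so concatenating the slots in order preserves node order.
    picked_per_chunk[begin / grain_size] = torch::cat(picked_local);
  });

  // Main thread: cat a handful of partial tensors, not one tensor per node,
  // and build the sampled subgraph's indptr from the per-node pick counts.
  torch::Tensor picked_eids = torch::cat(picked_per_chunk);
  torch::Tensor subgraph_indptr = torch::cumsum(num_picked_per_node, 0);

  std::cout << "picked_eids:\n" << picked_eids << "\n";
  std::cout << "subgraph_indptr:\n" << subgraph_indptr << "\n";
  return 0;
}
```

With this layout the final `torch::cat` touches at most `num_slots` tensors rather than `num_nodes`, which is the saving this commit targets, and `subgraph_indptr` is still derived from the per-node pick counts via `cumsum`, as in the patched `SampleNeighborsImpl`.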