Unverified commit 80d16efa, authored by keli-wen and committed by GitHub
Browse files

[Graphbolt] Add `cat` optimization for UniformPick (#6030)

parent 327589c8
......@@ -141,50 +141,63 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool consider_etype = (fanouts.size() > 1);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
const int64_t num_threads = torch::get_num_threads();
std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options());
// Calculate GrainSize for parallel_for.
// Set the default grain size to 64.
const int64_t grain_size = 64;
AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "parallel_for", ([&] {
torch::parallel_for(0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
for (scalar_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
torch::parallel_for(
0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
const auto indptr_options = indptr_.options();
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
// Get current thread id.
auto thread_id = torch::get_thread_num();
int64_t local_grain_size = end - begin;
std::vector<torch::Tensor> picked_neighbors_cur_thread(
local_grain_size);
if (num_neighbors == 0) {
// To avoid crashing during concatenation in the master thread,
// initializing with empty tensors.
picked_neighbors_per_node[i] =
torch::tensor({}, indptr_.options());
continue;
}
for (scalar_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value(), probs_or_mask, args);
} else {
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask, args);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
}); // End of the thread.
}));
if (num_neighbors == 0) {
// To avoid crashing during concatenation in the master
// thread, initializing with empty tensors.
picked_neighbors_cur_thread[i - begin] =
torch::tensor({}, indptr_options);
continue;
}
if (consider_etype) {
picked_neighbors_cur_thread[i - begin] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_options,
type_per_edge_.value(), probs_or_mask, args);
} else {
picked_neighbors_cur_thread[i - begin] = Pick(
offset, num_neighbors, fanouts[0], replace,
indptr_options, probs_or_mask, args);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_cur_thread[i - begin].size(0);
}
picked_neighbors_per_thread[thread_id] =
torch::cat(picked_neighbors_cur_thread);
}); // End of parallel_for.
}));
torch::Tensor subgraph_indptr =
torch::cumsum(num_picked_neighbors_per_node, 0);
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.