"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "99b30a049ecfd98e37244672b1d8db0774ecb9b4"
Unverified Commit 2c03fe99 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Dispatch edge ids in neighbor sampling (#5889)

parent 2489f579
......@@ -141,37 +141,42 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options());
torch::parallel_for(0, num_nodes, 32, [&](size_t b, size_t e) {
for (size_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the graph's "
"node IDs.");
const auto offset = indptr_[nid].item<int64_t>();
const auto num_neighbors = indptr_[nid + 1].item<int64_t>() - offset;
AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "parallel_for", ([&] {
torch::parallel_for(0, num_nodes, 32, [&](scalar_t b, scalar_t e) {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
for (scalar_t i = b; i < e; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
if (num_neighbors == 0) {
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined tensor
// during concatenation can result in a crash.
picked_neighbors_per_node[i] = torch::tensor({}, indptr_.options());
continue;
}
if (num_neighbors == 0) {
// Initialization is performed here because all tensors will be
// concatenated in the master thread, and having an undefined
// tensor during concatenation can result in a crash.
picked_neighbors_per_node[i] =
torch::tensor({}, indptr_.options());
continue;
}
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value(), probs_or_mask);
} else {
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
}); // End of the thread.
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value(), probs_or_mask);
} else {
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
}
}); // End of the thread.
}));
torch::Tensor subgraph_indptr =
torch::cumsum(num_picked_neighbors_per_node, 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment