Unverified Commit 0f3f8181 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Graphbolt] Improve Labor performance (#6203)

parent 27f6561a
......@@ -406,8 +406,8 @@ int64_t PickByEtype(
PickedType* picked_data_ptr);
template <
bool NonUniform, bool Replace, typename ProbsType = float,
typename PickedType>
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024>
int64_t LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
......
......@@ -8,6 +8,8 @@
#include <graphbolt/serialize.h>
#include <torch/torch.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <numeric>
......@@ -730,11 +732,11 @@ int64_t Pick(
return UniformPick(
offset, num_neighbors, fanout, replace, options, picked_data_ptr);
} else if (replace) {
return LaborPick<false, true>(
return LaborPick<false, true, float>(
offset, num_neighbors, fanout, options,
/* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
} else { // replace = false
return LaborPick<false, false>(
return LaborPick<false, false, float>(
offset, num_neighbors, fanout, options,
/* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
}
......@@ -770,7 +772,8 @@ inline void safe_divide(T& a, U b) {
* should be put. Enough memory space should be allocated in advance.
*/
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType>
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize>
inline int64_t LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
......@@ -781,10 +784,16 @@ inline int64_t LaborPick(
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
return num_neighbors;
}
torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
// Assuming max_degree of a vertex is <= 4 billion.
auto heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
heap_tensor.data_ptr<int32_t>());
std::array<std::pair<float, uint32_t>, StackSize> heap;
auto heap_data = heap.data();
torch::Tensor heap_tensor;
if (fanout > StackSize) {
constexpr int factor = sizeof(heap_data[0]) / sizeof(int32_t);
heap_tensor = torch::empty({fanout * factor}, torch::kInt32);
heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
heap_tensor.data_ptr<int32_t>());
}
const ProbsType* local_probs_data =
NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset
: nullptr;
......@@ -814,22 +823,29 @@ inline int64_t LaborPick(
// is O((fanout + num_neighbors) log(fanout)). It is possible to
// decrease the logarithmic factor down to
// O(log(min(fanout, num_neighbors))).
torch::Tensor remaining =
torch::ones({num_neighbors}, torch::kFloat32);
float* rem_data = remaining.data_ptr<float>();
std::array<float, StackSize> remaining;
auto remaining_data = remaining.data();
torch::Tensor remaining_tensor;
if (num_neighbors > StackSize) {
remaining_tensor = torch::empty({num_neighbors}, torch::kFloat32);
remaining_data = remaining_tensor.data_ptr<float>();
}
std::fill_n(remaining_data, num_neighbors, 1.f);
auto heap_end = heap_data;
const auto init_count = (num_neighbors + fanout - 1) / num_neighbors;
auto sample_neighbor_i_with_index_t_jth_time =
[&](scalar_t t, int64_t j, uint32_t i) {
auto rnd = labor::jth_sorted_uniform_random(
args.random_seed, t, args.num_nodes, j, rem_data[i],
args.random_seed, t, args.num_nodes, j, remaining_data[i],
fanout - j); // r_t
if constexpr (NonUniform) {
safe_divide(rnd, local_probs_data[i]);
} // r_t / \pi_t
if (heap_end < heap_data + fanout) {
heap_end[0] = std::make_pair(rnd, i);
std::push_heap(heap_data, ++heap_end);
if (++heap_end >= heap_data + fanout) {
std::make_heap(heap_data, heap_data + fanout);
}
return false;
} else if (rnd < heap_data[0].first) {
std::pop_heap(heap_data, heap_data + fanout);
......@@ -837,18 +853,18 @@ inline int64_t LaborPick(
std::push_heap(heap_data, heap_data + fanout);
return false;
} else {
rem_data[i] = -1;
remaining_data[i] = -1;
return true;
}
};
for (uint32_t i = 0; i < num_neighbors; ++i) {
const auto t = local_indices_data[i];
for (int64_t j = 0; j < init_count; j++) {
const auto t = local_indices_data[i];
sample_neighbor_i_with_index_t_jth_time(t, j, i);
}
}
for (uint32_t i = 0; i < num_neighbors; ++i) {
if (rem_data[i] == -1) continue;
if (remaining_data[i] == -1) continue;
const auto t = local_indices_data[i];
for (int64_t j = init_count; j < fanout; ++j) {
if (sample_neighbor_i_with_index_t_jth_time(t, j, i)) break;
......@@ -906,9 +922,6 @@ inline int64_t LaborPick(
picked_data_ptr[num_sampled++] = offset + j;
}
}
TORCH_CHECK(
!Replace || num_sampled == fanout || num_sampled == 0,
"Sampling with replacement should sample exactly fanout neighbors or 0!");
return num_sampled;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment