Unverified Commit e1781586 authored by keli-wen's avatar keli-wen Committed by GitHub
Browse files

[Graphbolt] Add optimization for UniformPick in Sampling. (#5932)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent e2c7dd32
......@@ -11,6 +11,7 @@
#include <tuple>
#include <vector>
#include "./random.h"
#include "./shared_memory_utils.h"
namespace graphbolt {
......@@ -268,8 +269,91 @@ inline torch::Tensor UniformPick(
picked_neighbors =
torch::randint(offset, offset + num_neighbors, {fanout}, options);
} else {
picked_neighbors = torch::randperm(num_neighbors, options);
picked_neighbors = picked_neighbors.slice(0, 0, fanout) + offset;
picked_neighbors = torch::empty({fanout}, options);
AT_DISPATCH_INTEGRAL_TYPES(
picked_neighbors.scalar_type(), "UniformPick", ([&] {
scalar_t* picked_neighbors_data =
picked_neighbors.data_ptr<scalar_t>();
// We use different sampling strategies for different sampling cases.
if (fanout >= num_neighbors / 10) {
// [Algorithm]
// This algorithm is conceptually related to the Fisher-Yates
// shuffle.
//
// [Complexity Analysis]
// This algorithm's memory complexity is O(num_neighbors), but
// it generates fewer random numbers (O(fanout)).
//
// (Compare) The reservoir algorithm is one of the most classical
// sampling algorithms. Both the reservoir algorithm and our
// algorithm offer distinct advantages, so we compare them to
// illustrate our trade-offs.
// The reservoir algorithm is memory-efficient (O(fanout)) but
// creates many random numbers (O(num_neighbors)), which is
// costly.
//
// [Practical Consideration]
// Use this algorithm when `fanout >= num_neighbors / 10` to
// reduce computation.
// In the scenario above, memory complexity is not a concern due
// to the small size of both `fanout` and `num_neighbors`, and
// allocating a small amount of memory is cheap, so the
// algorithm performs well in this case.
std::vector<scalar_t> seq(num_neighbors);
// Fill seq with the half-open range [offset, offset + num_neighbors).
std::iota(seq.begin(), seq.end(), offset);
for (int64_t i = 0; i < fanout; ++i) {
auto j = RandomEngine::ThreadLocal()->RandInt(i, num_neighbors);
std::swap(seq[i], seq[j]);
}
// Save the randomly sampled fanout elements to the output tensor.
std::copy(seq.begin(), seq.begin() + fanout, picked_neighbors_data);
} else if (fanout < 64) {
// [Algorithm]
// Use linear search to verify uniqueness.
//
// [Complexity Analysis]
// Since the set of numbers is small (up to 64), it is more
// cost-effective for the CPU to use this algorithm.
auto begin = picked_neighbors_data;
auto end = picked_neighbors_data + fanout;
while (begin != end) {
// Write the new random number to the next unfilled position.
*begin = RandomEngine::ThreadLocal()->RandInt(
offset, offset + num_neighbors);
// Check that the new value does not already exist in the
// current range [picked_neighbors_data, begin). Otherwise,
// draw a new value until all elements in the range are unique.
auto it = std::find(picked_neighbors_data, begin, *begin);
if (it == begin) ++begin;
}
} else {
// [Algorithm]
// Use hash-set to verify uniqueness. In the best scenario, the
// time complexity is O(fanout), assuming no conflicts occur.
//
// [Complexity Analysis]
// Let K = fanout / num_neighbors. The expected number of extra
// sampling steps is roughly K^2 / (1 - K) * num_neighbors, which
// means that in the worst-case scenario the time complexity is
// O(num_neighbors^2).
//
// [Practical Consideration]
// In practice, we set the threshold K to 1/10. This trade-off is
// due to the slower performance of std::unordered_set, which
// would otherwise increase the sampling cost. By doing so, we
// achieve a balance between theoretical efficiency and practical
// performance.
std::unordered_set<scalar_t> picked_set;
while (static_cast<int64_t>(picked_set.size()) < fanout) {
picked_set.insert(RandomEngine::ThreadLocal()->RandInt(
offset, offset + num_neighbors));
}
std::copy(
picked_set.begin(), picked_set.end(), picked_neighbors_data);
}
}));
}
return picked_neighbors;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment