Unverified Commit bbc8ff62 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Graphbolt] Rewrite torch::multinomial to improve sampling performance (#6217)

parent 219c9f1a
......@@ -403,6 +403,12 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
probs_or_mask.value().dtype() == torch::kFloat16) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
TORCH_CHECK(
((probs_or_mask.value().max() < INFINITY) &
(probs_or_mask.value().min() >= 0))
.item()
.to<bool>(),
"Invalid probs_or_mask (contains either `inf`, `nan` or element < 0).");
}
if (layer) {
......@@ -690,11 +696,91 @@ inline int64_t NonUniformPick(
return num_positive_probs;
} else {
if (!replace) fanout = std::min(fanout, num_positive_probs);
std::memcpy(
picked_data_ptr,
(torch::multinomial(local_probs, fanout, replace) + offset)
.data_ptr<PickedType>(),
fanout * sizeof(PickedType));
if (fanout == 0) return 0;
AT_DISPATCH_FLOATING_TYPES(
local_probs.scalar_type(), "MultinomialSampling", ([&] {
auto local_probs_data_ptr = local_probs.data_ptr<scalar_t>();
auto positive_probs_indices_ptr =
positive_probs_indices.data_ptr<PickedType>();
if (!replace) {
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Here we can apply exp to the formula which will not affect result
// of argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1).
if (fanout == 1) {
// Return argmax(p / q).
scalar_t max_prob = 0;
PickedType max_prob_index = -1;
// We only care about the neighbors with non-zero probability.
for (auto i = 0; i < num_positive_probs; ++i) {
// Calculate (p / q) for the current neighbor.
scalar_t current_prob =
local_probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
if (current_prob > max_prob) {
max_prob = current_prob;
max_prob_index = positive_probs_indices_ptr[i];
}
}
*picked_data_ptr = max_prob_index + offset;
} else {
// Return topk(p / q).
std::vector<std::pair<scalar_t, PickedType>> q(
num_positive_probs);
for (auto i = 0; i < num_positive_probs; ++i) {
q[i].first =
local_probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
q[i].second = positive_probs_indices_ptr[i];
}
if (fanout < num_positive_probs / 64) {
// Use partial_sort.
std::partial_sort(
q.begin(), q.begin() + fanout, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
picked_data_ptr[i] = q[i].second + offset;
}
} else {
// Use nth_element.
std::nth_element(
q.begin(), q.begin() + fanout - 1, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
picked_data_ptr[i] = q[i].second + offset;
}
}
}
} else {
// Calculate cumulative sum of probabilities.
std::vector<scalar_t> prefix_sum_probs(num_positive_probs);
scalar_t sum_probs = 0;
for (auto i = 0; i < num_positive_probs; ++i) {
sum_probs += local_probs_data_ptr[positive_probs_indices_ptr[i]];
prefix_sum_probs[i] = sum_probs;
}
// Normalize.
if ((sum_probs > 1.00001) || (sum_probs < 0.99999)) {
for (auto i = 0; i < num_positive_probs; ++i) {
prefix_sum_probs[i] /= sum_probs;
}
}
for (auto i = 0; i < fanout; ++i) {
// Sample a probability mass from a uniform distribution.
double uniform_sample =
RandomEngine::ThreadLocal()->Uniform(0., 1.);
// Use a binary search to find the index.
int sampled_index = std::lower_bound(
prefix_sum_probs.begin(),
prefix_sum_probs.end(), uniform_sample) -
prefix_sum_probs.begin();
picked_data_ptr[i] =
positive_probs_indices_ptr[sampled_index] + offset;
}
}
}));
return fanout;
}
}
......
......@@ -69,6 +69,25 @@ class RandomEngine {
return dist(rng_);
}
/**
* @brief Generate a uniform random real number in [low, high).
*/
template <typename T>
T Uniform(T lower, T upper) {
std::uniform_real_distribution<T> dist(lower, upper);
return dist(rng_);
}
/**
* @brief Generate random non-negative floating-point values according to
* exponential distribution. Probability density function: P(x|λ) = λe^(-λx).
*/
template <typename T>
T Exponential(T lambda) {
std::exponential_distribution<T> dist(lambda);
return dist(rng_);
}
private:
pcg32 rng_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment