Unverified Commit 11bdd6e8 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Refactor the nonuniform pick function to make it reusable. (#6772)

parent 3d657dbf
......@@ -818,6 +818,103 @@ inline int64_t UniformPick(
}
}
/** @brief An operator to perform non-uniform sampling. */
static torch::Tensor NonUniformPickOp(
torch::Tensor probs, int64_t fanout, bool replace) {
auto positive_probs_indices = probs.nonzero().squeeze(1);
auto num_positive_probs = positive_probs_indices.size(0);
if (num_positive_probs == 0) return torch::empty({0}, torch::kLong);
if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
return positive_probs_indices;
}
if (!replace) fanout = std::min(fanout, num_positive_probs);
if (fanout == 0) return torch::empty({0}, torch::kLong);
auto ret_tensor = torch::empty({fanout}, torch::kLong);
auto ret_ptr = ret_tensor.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(
probs.scalar_type(), "MultinomialSampling", ([&] {
auto probs_data_ptr = probs.data_ptr<scalar_t>();
auto positive_probs_indices_ptr =
positive_probs_indices.data_ptr<int64_t>();
if (!replace) {
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Here we can apply exp to the formula which will not affect result
// of argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1).
if (fanout == 1) {
// Return argmax(p / q).
scalar_t max_prob = 0;
int64_t max_prob_index = -1;
// We only care about the neighbors with non-zero probability.
for (auto i = 0; i < num_positive_probs; ++i) {
// Calculate (p / q) for the current neighbor.
scalar_t current_prob =
probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
if (current_prob > max_prob) {
max_prob = current_prob;
max_prob_index = positive_probs_indices_ptr[i];
}
}
ret_ptr[0] = max_prob_index;
} else {
// Return topk(p / q).
std::vector<std::pair<scalar_t, int64_t>> q(num_positive_probs);
for (auto i = 0; i < num_positive_probs; ++i) {
q[i].first = probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
q[i].second = positive_probs_indices_ptr[i];
}
if (fanout < num_positive_probs / 64) {
// Use partial_sort.
std::partial_sort(
q.begin(), q.begin() + fanout, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
ret_ptr[i] = q[i].second;
}
} else {
// Use nth_element.
std::nth_element(
q.begin(), q.begin() + fanout - 1, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
ret_ptr[i] = q[i].second;
}
}
}
} else {
// Calculate cumulative sum of probabilities.
std::vector<scalar_t> prefix_sum_probs(num_positive_probs);
scalar_t sum_probs = 0;
for (auto i = 0; i < num_positive_probs; ++i) {
sum_probs += probs_data_ptr[positive_probs_indices_ptr[i]];
prefix_sum_probs[i] = sum_probs;
}
// Normalize.
if ((sum_probs > 1.00001) || (sum_probs < 0.99999)) {
for (auto i = 0; i < num_positive_probs; ++i) {
prefix_sum_probs[i] /= sum_probs;
}
}
for (auto i = 0; i < fanout; ++i) {
// Sample a probability mass from a uniform distribution.
double uniform_sample =
RandomEngine::ThreadLocal()->Uniform(0., 1.);
// Use a binary search to find the index.
int sampled_index = std::lower_bound(
prefix_sum_probs.begin(),
prefix_sum_probs.end(), uniform_sample) -
prefix_sum_probs.begin();
ret_ptr[i] = positive_probs_indices_ptr[sampled_index];
}
}
}));
return ret_tensor;
}
/**
* @brief Perform non-uniform sampling of elements based on probabilities and
* return the sampled indices.
......@@ -861,104 +958,13 @@ inline int64_t NonUniformPick(
PickedType* picked_data_ptr) {
auto local_probs =
probs_or_mask.value().slice(0, offset, offset + num_neighbors);
auto positive_probs_indices = local_probs.nonzero().squeeze(1);
auto num_positive_probs = positive_probs_indices.size(0);
if (num_positive_probs == 0) return 0;
if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
std::memcpy(
picked_data_ptr,
(positive_probs_indices + offset).data_ptr<PickedType>(),
num_positive_probs * sizeof(PickedType));
return num_positive_probs;
} else {
if (!replace) fanout = std::min(fanout, num_positive_probs);
if (fanout == 0) return 0;
AT_DISPATCH_FLOATING_TYPES(
local_probs.scalar_type(), "MultinomialSampling", ([&] {
auto local_probs_data_ptr = local_probs.data_ptr<scalar_t>();
auto positive_probs_indices_ptr =
positive_probs_indices.data_ptr<PickedType>();
if (!replace) {
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Here we can apply exp to the formula which will not affect result
// of argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1).
if (fanout == 1) {
// Return argmax(p / q).
scalar_t max_prob = 0;
PickedType max_prob_index = -1;
// We only care about the neighbors with non-zero probability.
for (auto i = 0; i < num_positive_probs; ++i) {
// Calculate (p / q) for the current neighbor.
scalar_t current_prob =
local_probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
if (current_prob > max_prob) {
max_prob = current_prob;
max_prob_index = positive_probs_indices_ptr[i];
}
}
*picked_data_ptr = max_prob_index + offset;
} else {
// Return topk(p / q).
std::vector<std::pair<scalar_t, PickedType>> q(
num_positive_probs);
for (auto i = 0; i < num_positive_probs; ++i) {
q[i].first =
local_probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
q[i].second = positive_probs_indices_ptr[i];
}
if (fanout < num_positive_probs / 64) {
// Use partial_sort.
std::partial_sort(
q.begin(), q.begin() + fanout, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
picked_data_ptr[i] = q[i].second + offset;
}
} else {
// Use nth_element.
std::nth_element(
q.begin(), q.begin() + fanout - 1, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
picked_data_ptr[i] = q[i].second + offset;
}
}
}
} else {
// Calculate cumulative sum of probabilities.
std::vector<scalar_t> prefix_sum_probs(num_positive_probs);
scalar_t sum_probs = 0;
for (auto i = 0; i < num_positive_probs; ++i) {
sum_probs += local_probs_data_ptr[positive_probs_indices_ptr[i]];
prefix_sum_probs[i] = sum_probs;
}
// Normalize.
if ((sum_probs > 1.00001) || (sum_probs < 0.99999)) {
for (auto i = 0; i < num_positive_probs; ++i) {
prefix_sum_probs[i] /= sum_probs;
}
}
for (auto i = 0; i < fanout; ++i) {
// Sample a probability mass from a uniform distribution.
double uniform_sample =
RandomEngine::ThreadLocal()->Uniform(0., 1.);
// Use a binary search to find the index.
int sampled_index = std::lower_bound(
prefix_sum_probs.begin(),
prefix_sum_probs.end(), uniform_sample) -
prefix_sum_probs.begin();
picked_data_ptr[i] =
positive_probs_indices_ptr[sampled_index] + offset;
}
}
}));
return fanout;
auto picked_indices = NonUniformPickOp(local_probs, fanout, replace);
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
for (int i = 0; i < picked_indices.numel(); ++i) {
picked_data_ptr[i] =
static_cast<PickedType>(picked_indices_ptr[i]) + offset;
}
return picked_indices.numel();
}
template <typename PickedType>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment