#include
#include
#include |
/// Samples edge ids for every entry of `start`, without replacement per entry.
///
/// For a start node `s`, the candidate ids are the half-open range
/// `[cumdeg[s], cumdeg[s + 1])`. From that range
/// `min(size, ceil(factor * num_neighbors))` distinct ids are drawn, where
/// `num_neighbors = cumdeg[s + 1] - cumdeg[s]`.
///
/// @param start  Tensor of node indices; iterated along dimension 0.
/// @param cumdeg Tensor indexed by node and by node + 1 to obtain the
///               candidate range. Presumably a cumulative degree array with
///               one more entry than there are nodes; verify against caller.
/// @param size   Upper bound on the number of ids drawn per start node.
/// @param factor Multiplier applied to the neighbor count before the `size`
///               cap is applied.
/// @return A freshly allocated 1-D tensor (created with `start.options()`)
///         holding the drawn ids of all start nodes, concatenated in order.
///
/// NOTE(review): both inputs are read as int64_t, matching the int64_t locals
/// they are assigned to; confirm that callers always pass int64 CPU tensors.
/// Indices in `start` are not bounds-checked against `cumdeg`.
at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
                            float factor) {
  // Releases the TH generator on every exit path, including exceptions thrown
  // by tensor or container allocations below.
  struct GeneratorGuard {
    THGenerator *gen;
    explicit GeneratorGuard(THGenerator *g) : gen(g) {}
    ~GeneratorGuard() { THGenerator_free(gen); }
    GeneratorGuard(const GeneratorGuard &) = delete;
    GeneratorGuard &operator=(const GeneratorGuard &) = delete;
  };
  GeneratorGuard guard(THGenerator_new());

  // The raw pointers below are indexed linearly, which is only valid for
  // contiguous storage.
  start = start.contiguous();
  cumdeg = cumdeg.contiguous();

  auto start_ptr = start.data<int64_t>();
  auto cumdeg_ptr = cumdeg.data<int64_t>();

  std::vector<int64_t> e_ids;
  const int64_t num_start = start.size(0);
  for (int64_t i = 0; i < num_start; i++) {
    int64_t low = cumdeg_ptr[start_ptr[i]];
    int64_t high = cumdeg_ptr[start_ptr[i] + 1];
    size_t num_neighbors = high - low;

    size_t size_i = size_t(ceil(factor * float(num_neighbors)));
    size_i = (size_i < size) ? size_i : size;

    // If the number of requested neighbors is close to the number of
    // available neighbors, a random permutation is cheaper than rejection
    // sampling; otherwise random offsets are inserted into a set until it
    // holds enough distinct values.
    if (size_i < 0.7 * float(num_neighbors)) {
      // This branch is only reached with num_neighbors > 0 (size_i >= 0 would
      // not be smaller than 0.7 * 0), so the modulo is well defined.
      std::unordered_set<int64_t> set;
      while (set.size() < size_i) {
        int64_t z = THRandom_random(guard.gen) % num_neighbors;
        set.insert(z + low);
      }
      e_ids.insert(e_ids.end(), set.begin(), set.end());
    } else {
      auto sample = at::randperm(num_neighbors, start.options());
      auto sample_ptr = sample.data<int64_t>();
      for (size_t j = 0; j < size_i; j++) {
        e_ids.push_back(sample_ptr[j] + low);
      }
    }
  }

  if (e_ids.empty()) {
    // An empty vector may report a null data() pointer; build the empty
    // result directly instead of wrapping it.
    return at::empty({0}, start.options());
  }

  int64_t len = e_ids.size();
  // `from_blob` does not take ownership of the vector's memory, so the result
  // is cloned before `e_ids` goes out of scope.
  auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
  return e_id;
}
// Python bindings: registers `neighbor_sampler` on the extension module whose
// name is supplied by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  module.def("neighbor_sampler", &neighbor_sampler, "Neighbor Sampler (CPU)");
}
| |