Unverified Commit e465ac09 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #23 from Dawars/patch-2

Fixes breakage with PyTorch nightly #22
parents f7e9e8fd 88a60c40
#include <TH/THRandom.h>
#include <ATen/CPUGenerator.h>
#include <torch/extension.h>
#include <TH/THGenerator.hpp>
at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
float factor) {
THGenerator *generator = THGenerator_new();
CPUGenerator* generator = at::detail::getDefaultCPUGenerator();
auto start_ptr = start.data<int64_t>();
auto cumdeg_ptr = cumdeg.data<int64_t>();
......@@ -26,7 +24,7 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
std::unordered_set<int64_t> set;
if (size_i < 0.7 * float(num_neighbors)) {
while (set.size() < size_i) {
int64_t z = THRandom_random(generator) % num_neighbors;
int64_t z = generator->random() % num_neighbors;
set.insert(z + low);
}
std::vector<int64_t> v(set.begin(), set.end());
......@@ -40,8 +38,6 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
}
}
THGenerator_free(generator);
int64_t len = e_ids.size();
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
return e_id;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment