Commit 4116005f authored by rusty1s's avatar rusty1s
Browse files

fix neighbor sampling

parent 69fada5e
...@@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr, ...@@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
auto num_neighbors = row_end - row_start; auto num_neighbors = row_end - row_start;
int64_t size = count; int64_t size = count;
if (count < 1) { if (count < 1)
size = int64_t(ceil(factor * float(num_neighbors))); size = int64_t(ceil(factor * float(num_neighbors)));
} if (size > num_neighbors)
size = num_neighbors;
// If the number of neighbors is approximately equal to the number of // If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without // neighbors which are requested, we use `randperm` to sample without
...@@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr, ...@@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
std::unordered_set<int64_t> set; std::unordered_set<int64_t> set;
if (size < 0.7 * float(num_neighbors)) { if (size < 0.7 * float(num_neighbors)) {
while (int64_t(set.size()) < size) { while (int64_t(set.size()) < size) {
int64_t sample = (rand() % num_neighbors) + row_start; int64_t sample = rand() % num_neighbors;
set.insert(sample); set.insert(sample + row_start);
} }
std::vector<int64_t> v(set.begin(), set.end()); std::vector<int64_t> v(set.begin(), set.end());
e_ids.insert(e_ids.end(), v.begin(), v.end()); e_ids.insert(e_ids.end(), v.begin(), v.end());
} else { } else {
auto sample = at::randperm(num_neighbors, start.options()) + row_start; auto sample = torch::randperm(num_neighbors, start.options());
auto sample_data = sample.data_ptr<int64_t>(); auto sample_data = sample.data_ptr<int64_t>();
for (auto j = 0; j < size; j++) { for (auto j = 0; j < size; j++) {
e_ids.push_back(sample_data[j]); e_ids.push_back(sample_data[j] + row_start);
} }
} }
} }
......
...@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov'] ...@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
setup( setup(
name='torch_cluster', name='torch_cluster',
version='1.5.1', version='1.5.2',
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster', url='https://github.com/rusty1s/pytorch_cluster',
......
...@@ -3,7 +3,7 @@ import os.path as osp ...@@ -3,7 +3,7 @@ import os.path as osp
import torch import torch
__version__ = '1.5.1' __version__ = '1.5.2'
expected_torch_version = (1, 4) expected_torch_version = (1, 4)
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment