"vscode:/vscode.git/clone" did not exist on "4433a5b2e7570ca86b08e0c229e3ac5bed8046bb"
Commit 4116005f authored by rusty1s's avatar rusty1s
Browse files

fix neighbor sampling

parent 69fada5e
......@@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
auto num_neighbors = row_end - row_start;
int64_t size = count;
if (count < 1) {
if (count < 1)
size = int64_t(ceil(factor * float(num_neighbors)));
}
if (size > num_neighbors)
size = num_neighbors;
// If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without
......@@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
std::unordered_set<int64_t> set;
if (size < 0.7 * float(num_neighbors)) {
while (int64_t(set.size()) < size) {
int64_t sample = (rand() % num_neighbors) + row_start;
set.insert(sample);
int64_t sample = rand() % num_neighbors;
set.insert(sample + row_start);
}
std::vector<int64_t> v(set.begin(), set.end());
e_ids.insert(e_ids.end(), v.begin(), v.end());
} else {
auto sample = at::randperm(num_neighbors, start.options()) + row_start;
auto sample = torch::randperm(num_neighbors, start.options());
auto sample_data = sample.data_ptr<int64_t>();
for (auto j = 0; j < size; j++) {
e_ids.push_back(sample_data[j]);
e_ids.push_back(sample_data[j] + row_start);
}
}
}
......
......@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_cluster',
version='1.5.1',
version='1.5.2',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster',
......
......@@ -3,7 +3,7 @@ import os.path as osp
import torch
__version__ = '1.5.1'
__version__ = '1.5.2'
expected_torch_version = (1, 4)
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment