Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
4116005f
Commit
4116005f
authored
Mar 17, 2020
by
rusty1s
Browse files
fix neighbor sampling
parent
69fada5e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
8 deletions
+9
-8
csrc/cpu/sampler_cpu.cpp
csrc/cpu/sampler_cpu.cpp
+7
-6
setup.py
setup.py
+1
-1
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
No files found.
csrc/cpu/sampler_cpu.cpp
View file @
4116005f
...
...
@@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
auto
num_neighbors
=
row_end
-
row_start
;
int64_t
size
=
count
;
if
(
count
<
1
)
{
if
(
count
<
1
)
size
=
int64_t
(
ceil
(
factor
*
float
(
num_neighbors
)));
}
if
(
size
>
num_neighbors
)
size
=
num_neighbors
;
// If the number of neighbors is approximately equal to the number of
// neighbors which are requested, we use `randperm` to sample without
...
...
@@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
std
::
unordered_set
<
int64_t
>
set
;
if
(
size
<
0.7
*
float
(
num_neighbors
))
{
while
(
int64_t
(
set
.
size
())
<
size
)
{
int64_t
sample
=
(
rand
()
%
num_neighbors
)
+
row_start
;
set
.
insert
(
sample
);
int64_t
sample
=
rand
()
%
num_neighbors
;
set
.
insert
(
sample
+
row_start
);
}
std
::
vector
<
int64_t
>
v
(
set
.
begin
(),
set
.
end
());
e_ids
.
insert
(
e_ids
.
end
(),
v
.
begin
(),
v
.
end
());
}
else
{
auto
sample
=
a
t
::
randperm
(
num_neighbors
,
start
.
options
())
+
row_start
;
auto
sample
=
t
orch
::
randperm
(
num_neighbors
,
start
.
options
());
auto
sample_data
=
sample
.
data_ptr
<
int64_t
>
();
for
(
auto
j
=
0
;
j
<
size
;
j
++
)
{
e_ids
.
push_back
(
sample_data
[
j
]);
e_ids
.
push_back
(
sample_data
[
j
]
+
row_start
);
}
}
}
...
...
setup.py
View file @
4116005f
...
...
@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
setup
(
name
=
'torch_cluster'
,
version
=
'1.5.
1
'
,
version
=
'1.5.
2
'
,
author
=
'Matthias Fey'
,
author_email
=
'matthias.fey@tu-dortmund.de'
,
url
=
'https://github.com/rusty1s/pytorch_cluster'
,
...
...
torch_cluster/__init__.py
View file @
4116005f
...
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
__version__
=
'1.5.
1
'
__version__
=
'1.5.
2
'
expected_torch_version
=
(
1
,
4
)
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment