Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
9c33077e
"examples/vscode:/vscode.git/clone" did not exist on "e828232780554d54bdb527d3390fab0be042b72a"
Commit
9c33077e
authored
Mar 24, 2020
by
rusty1s
Browse files
fix fps implementation
parent
aff91e0e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
21 additions
and
9 deletions
+21
-9
csrc/cpu/fps_cpu.cpp
csrc/cpu/fps_cpu.cpp
+1
-1
csrc/cuda/fps_cuda.cu
csrc/cuda/fps_cuda.cu
+6
-6
setup.py
setup.py
+1
-1
test/test_fps.py
test/test_fps.py
+12
-0
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
No files found.
csrc/cpu/fps_cpu.cpp
View file @
9c33077e
...
@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
...
@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
int64_t
start_idx
=
0
;
int64_t
start_idx
=
0
;
if
(
random_start
)
{
if
(
random_start
)
{
start_idx
=
rand
()
%
src
.
size
(
0
);
start_idx
=
rand
()
%
y
.
size
(
0
);
}
}
out_data
[
out_start
]
=
src_start
+
start_idx
;
out_data
[
out_start
]
=
src_start
+
start_idx
;
...
...
csrc/cuda/fps_cuda.cu
View file @
9c33077e
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "utils.cuh"
#include "utils.cuh"
#define THREADS
1024
#define THREADS
256
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
fps_kernel
(
const
scalar_t
*
src
,
const
int64_t
*
ptr
,
__global__
void
fps_kernel
(
const
scalar_t
*
src
,
const
int64_t
*
ptr
,
...
@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
...
@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
int64_t
best_idx
=
0
;
int64_t
best_idx
=
0
;
for
(
int64_t
n
=
start_idx
+
thread_idx
;
n
<
end_idx
;
n
+=
THREADS
)
{
for
(
int64_t
n
=
start_idx
+
thread_idx
;
n
<
end_idx
;
n
+=
THREADS
)
{
scalar_t
tmp
;
scalar_t
tmp
,
dd
=
(
scalar_t
)
0.
;
scalar_t
dd
=
(
scalar_t
)
0.
;
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
tmp
=
src
[
dim
*
old
+
d
]
-
src
[
dim
*
n
+
d
];
tmp
=
src
[
dim
*
old
+
d
]
-
src
[
dim
*
n
+
d
];
dd
+=
tmp
*
tmp
;
dd
+=
tmp
*
tmp
;
}
}
dist
[
n
]
=
min
(
dist
[
n
],
dd
);
dd
=
min
(
dist
[
n
],
dd
);
if
(
dist
[
n
]
>
best
)
{
dist
[
n
]
=
dd
;
best
=
dist
[
n
];
if
(
dd
>
best
)
{
best
=
dd
;
best_idx
=
n
;
best_idx
=
n
;
}
}
}
}
...
...
setup.py
View file @
9c33077e
...
@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
...
@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
setup
(
setup
(
name
=
'torch_cluster'
,
name
=
'torch_cluster'
,
version
=
'1.5.
2
'
,
version
=
'1.5.
3
'
,
author
=
'Matthias Fey'
,
author
=
'Matthias Fey'
,
author_email
=
'matthias.fey@tu-dortmund.de'
,
author_email
=
'matthias.fey@tu-dortmund.de'
,
url
=
'https://github.com/rusty1s/pytorch_cluster'
,
url
=
'https://github.com/rusty1s/pytorch_cluster'
,
...
...
test/test_fps.py
View file @
9c33077e
...
@@ -26,3 +26,15 @@ def test_fps(dtype, device):
...
@@ -26,3 +26,15 @@ def test_fps(dtype, device):
out
=
fps
(
x
,
ratio
=
0.5
,
random_start
=
False
)
out
=
fps
(
x
,
ratio
=
0.5
,
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_random_fps
(
device
):
N
=
1024
for
_
in
range
(
5
):
pos
=
torch
.
randn
((
2
*
N
,
3
),
device
=
device
)
batch_1
=
torch
.
zeros
(
N
,
dtype
=
torch
.
long
,
device
=
device
)
batch_2
=
torch
.
ones
(
N
,
dtype
=
torch
.
long
,
device
=
device
)
batch
=
torch
.
cat
([
batch_1
,
batch_2
])
idx
=
fps
(
pos
,
batch
,
ratio
=
0.5
)
assert
idx
.
min
()
>=
0
and
idx
.
max
()
<
2
*
N
torch_cluster/__init__.py
View file @
9c33077e
...
@@ -3,7 +3,7 @@ import os.path as osp
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
import
torch
__version__
=
'1.5.
2
'
__version__
=
'1.5.
3
'
expected_torch_version
=
(
1
,
4
)
expected_torch_version
=
(
1
,
4
)
try
:
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment