OpenDAS / torch-cluster · Commits

Commit b0f9f81b
authored Dec 18, 2018 by rusty1s

    fps cpu version

parent 0a038334
Showing 3 changed files with 7 additions and 32 deletions (+7, -32)
cuda/fps_kernel.cu      +1   -1
test/test_fps.py        +1   -26
torch_cluster/fps.py    +5   -5
cuda/fps_kernel.cu

@@ -169,7 +169,7 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
   auto deg = degree(batch, batch_size);
   auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

-  auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong);
+  auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
   auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

   at::Tensor start;
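The functional change in the kernel is that the per-example sample count k switches from round() to ceil(), so every non-empty example contributes at least one sampled point. A small PyTorch sketch of the difference (illustrative only, not part of the commit):

import torch

# Node counts ("degrees") of three examples in a batch, and the sampling ratio.
deg = torch.tensor([7, 3, 1])
ratio = 0.4

k_round = (deg.float() * ratio).round().long()  # tensor([3, 1, 0]) -- the last example gets no point
k_ceil = (deg.float() * ratio).ceil().long()    # tensor([3, 2, 1]) -- at least one point each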
test/test_fps.py

@@ -4,12 +4,9 @@ import pytest
 import torch
 from torch_cluster import fps
-from .utils import tensor, grad_dtypes
-
-devices = [torch.device('cuda')]
+from .utils import grad_dtypes, devices, tensor


-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_fps(dtype, device):
     x = tensor([
@@ -26,25 +23,3 @@ def test_fps(dtype, device):
     out = fps(x, batch, ratio=0.5, random_start=False)
     assert out.tolist() == [0, 2, 4, 6]
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
-def test_fps_speed(dtype, device):
-    return
-    batch_size, num_nodes = 100, 10000
-    x = torch.randn((batch_size * num_nodes, 3), dtype=dtype, device=device)
-    batch = torch.arange(batch_size, dtype=torch.long, device=device)
-    batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
-    out = fps(x, batch, ratio=0.5, random_start=True)
-    assert out.size(0) == batch_size * num_nodes * 0.5
-    assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
-
-    batch_size, num_nodes, dim = 100, 300, 128
-    x = torch.randn((batch_size * num_nodes, dim), dtype=dtype, device=device)
-    batch = torch.arange(batch_size, dtype=torch.long, device=device)
-    batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
-    out = fps(x, batch, ratio=0.5, random_start=True)
-    assert out.size(0) == batch_size * num_nodes * 0.5
-    assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
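The hard-coded `devices = [torch.device('cuda')]` list and the CUDA-only skip marker are dropped; `devices` is now imported from the shared test utilities, so the same parametrization exercises the new CPU path as well as CUDA. The commit does not touch `test/utils.py`, but a helper of roughly this shape would match the imports above (an assumption for illustration, not code from the repository):

import torch

# Hypothetical sketch of test/utils.py consistent with the imports above.
grad_dtypes = [torch.float, torch.double]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device('cuda')]


def tensor(x, dtype, device):
    # Build the test input on the requested dtype and device.
    return torch.tensor(x, dtype=dtype, device=device)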
torch_cluster/fps.py

 import torch
+import fps_cpu

 if torch.cuda.is_available():
     import fps_cuda
...
@@ -39,12 +40,11 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
     x = x.view(-1, 1) if x.dim() == 1 else x

-    assert x.is_cuda
     assert x.dim() == 2 and batch.dim() == 1
     assert x.size(0) == batch.size(0)
     assert ratio > 0 and ratio < 1

-    op = fps_cuda.fps if x.is_cuda else None
-    out = op(x, batch, ratio, random_start)
-
-    return out
+    if x.is_cuda:
+        return fps_cuda.fps(x, batch, ratio, random_start)
+    else:
+        return fps_cpu.fps(x, batch, ratio, random_start)
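With this dispatch in place, `fps` no longer requires a CUDA tensor: inputs on the CPU are routed to the new `fps_cpu` extension. A minimal usage sketch (illustrative; the indices assume the deterministic start and the point ordering below):

import torch
from torch_cluster import fps

# Two examples of four points each: the corners of two squares.
x = torch.tensor([
    [-1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [1.0, -1.0],
    [-2.0, -2.0], [-2.0, 2.0], [2.0, 2.0], [2.0, -2.0],
])
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

# ratio=0.5 keeps two points per example; with random_start=False sampling
# starts at each example's first point, so the opposite corner is picked next.
out = fps(x, batch, ratio=0.5, random_start=False)
print(out.tolist())  # expected: [0, 2, 4, 6]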