Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
882c8e08
Commit
882c8e08
authored
Dec 07, 2020
by
rusty1s
Browse files
test jit script
parent
53959eee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
8 deletions
+24
-8
test/test_fps.py
test/test_fps.py
+9
-0
torch_cluster/fps.py
torch_cluster/fps.py
+15
-8
No files found.
test/test_fps.py
View file @
882c8e08
...
...
@@ -2,11 +2,17 @@ from itertools import product
import
pytest
import
torch
from
torch
import
Tensor
from
torch_cluster
import
fps
from
.utils
import
grad_dtypes
,
devices
,
tensor
@
torch
.
jit
.
script
def
fps2
(
x
:
Tensor
,
ratio
:
Tensor
)
->
Tensor
:
return
fps
(
x
,
None
,
ratio
,
False
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_fps
(
dtype
,
device
):
x
=
tensor
([
...
...
@@ -47,6 +53,9 @@ def test_fps(dtype, device):
out
=
fps
(
x
,
ratio
=
torch
.
tensor
([
0.5
],
device
=
device
),
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
out
=
fps2
(
x
,
torch
.
tensor
([
0.5
],
device
=
device
))
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_random_fps
(
device
):
...
...
torch_cluster/fps.py
View file @
882c8e08
from
typing
import
Optional
from
torch
import
Tensor
import
torch
from
torch
import
Tensor
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
):
# type: (Tensor, Optional[Tensor], Optional[
in
t], bool) -> Tensor
def
fps
(
src
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
# type: (Tensor, Optional[Tensor], Optional[
floa
t], bool) -> Tensor
pass
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
):
def
fps
(
src
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def
fps
(
src
:
torch
.
Tensor
,
batch
=
None
,
ratio
=
0.5
,
random_start
=
True
):
# noqa
def
fps
(
src
:
torch
.
Tensor
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
# noqa
r
""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...
...
@@ -45,8 +46,14 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
index = fps(src, batch, ratio=0.5)
"""
if
not
isinstance
(
ratio
,
Tensor
):
ratio
=
torch
.
tensor
(
ratio
,
device
=
src
.
device
)
r
:
Optional
[
Tensor
]
=
None
if
ratio
is
None
:
r
=
torch
.
tensor
(
0.5
,
dtype
=
src
.
dtype
,
device
=
src
.
device
)
elif
isinstance
(
ratio
,
float
):
r
=
torch
.
tensor
(
ratio
,
dtype
=
src
.
dtype
,
device
=
src
.
device
)
else
:
r
=
ratio
assert
r
is
not
None
if
batch
is
not
None
:
assert
src
.
size
(
0
)
==
batch
.
numel
()
...
...
@@ -60,4 +67,4 @@ def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
else
:
ptr
=
torch
.
tensor
([
0
,
src
.
size
(
0
)],
device
=
src
.
device
)
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
ptr
,
r
atio
,
random_start
)
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
ptr
,
r
,
random_start
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment