Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
53959eee
Commit
53959eee
authored
Dec 07, 2020
by
rusty1s
Browse files
clean up
parent
5f1939fd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
37 additions
and
25 deletions
+37
-25
csrc/cpu/fps_cpu.cpp
csrc/cpu/fps_cpu.cpp
+2
-2
csrc/cuda/fps_cuda.cu
csrc/cuda/fps_cuda.cu
+5
-6
csrc/cuda/fps_cuda.h
csrc/cuda/fps_cuda.h
+2
-2
setup.py
setup.py
+1
-1
test/test_fps.py
test/test_fps.py
+18
-4
torch_cluster/fps.py
torch_cluster/fps.py
+9
-10
No files found.
csrc/cpu/fps_cpu.cpp
View file @
53959eee
...
...
@@ -13,12 +13,12 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
CHECK_CPU
(
src
);
CHECK_CPU
(
ptr
);
CHECK_CPU
(
ratio
);
CHECK_INPUT
(
ptr
.
dim
()
==
1
);
// AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
src
=
src
.
view
({
src
.
size
(
0
),
-
1
}).
contiguous
();
ptr
=
ptr
.
contiguous
();
auto
batch_size
=
ptr
.
size
(
0
)
-
1
;
auto
batch_size
=
ptr
.
numel
(
)
-
1
;
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
ratio
;
...
...
csrc/cuda/fps_cuda.cu
View file @
53959eee
...
...
@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#include <stdio.h>
#define THREADS 256
template
<
typename
scalar_t
>
...
...
@@ -64,19 +64,18 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}
}
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
)
{
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
)
{
CHECK_CUDA
(
src
);
CHECK_CUDA
(
ptr
);
CHECK_CUDA
(
ratio
);
CHECK_INPUT
(
ptr
.
dim
()
==
1
);
// AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
cudaSetDevice
(
src
.
get_device
());
src
=
src
.
view
({
src
.
size
(
0
),
-
1
}).
contiguous
();
ptr
=
ptr
.
contiguous
();
ratio
=
ratio
.
contiguous
();
auto
batch_size
=
ptr
.
size
(
0
)
-
1
;
auto
batch_size
=
ptr
.
numel
()
-
1
;
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
ratio
;
...
...
csrc/cuda/fps_cuda.h
View file @
53959eee
...
...
@@ -2,5 +2,5 @@
#include <torch/extension.h>
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
);
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
);
setup.py
View file @
53959eee
...
...
@@ -98,7 +98,7 @@ setup(
ext_modules
=
get_extensions
()
if
not
BUILD_DOCS
else
[],
cmdclass
=
{
'build_ext'
:
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
)
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
,
use_ninja
=
False
)
},
packages
=
find_packages
(),
)
test/test_fps.py
View file @
53959eee
...
...
@@ -21,16 +21,30 @@ def test_fps(dtype, device):
],
dtype
,
device
)
batch
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
out
=
fps
(
x
,
batch
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
batch
,
ratio
=
0.5
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
(
0.5
),
random_start
=
False
)
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
(
0.5
,
device
=
device
),
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
([
0.5
,
0.5
]),
random_start
=
False
)
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
([
0.5
,
0.5
],
device
=
device
),
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
ratio
=
torch
.
tensor
(
0.5
),
random_start
=
False
)
out
=
fps
(
x
,
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
out
=
fps
(
x
,
ratio
=
0.5
,
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
out
=
fps
(
x
,
ratio
=
torch
.
tensor
(
0.5
,
device
=
device
),
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
out
=
fps
(
x
,
ratio
=
torch
.
tensor
([
0.5
],
device
=
device
),
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
...
...
@@ -42,5 +56,5 @@ def test_random_fps(device):
batch_1
=
torch
.
zeros
(
N
,
dtype
=
torch
.
long
,
device
=
device
)
batch_2
=
torch
.
ones
(
N
,
dtype
=
torch
.
long
,
device
=
device
)
batch
=
torch
.
cat
([
batch_1
,
batch_2
])
idx
=
fps
(
pos
,
batch
,
ratio
=
torch
.
tensor
(
0.5
)
)
idx
=
fps
(
pos
,
batch
,
ratio
=
0.5
)
assert
idx
.
min
()
>=
0
and
idx
.
max
()
<
2
*
N
torch_cluster/fps.py
View file @
53959eee
...
...
@@ -3,19 +3,19 @@ from torch import Tensor
import
torch
@
torch
.
jit
.
_overload
def
fps
(
src
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
):
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor
pass
@
torch
.
jit
.
_overload
def
fps
(
src
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
@
torch
.
jit
.
_overload
# noqa
def
fps
(
src
,
batch
,
ratio
,
random_start
):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def
fps
(
src
:
torch
.
Tensor
,
batch
=
None
,
ratio
=
None
,
random_start
=
True
):
def
fps
(
src
:
torch
.
Tensor
,
batch
=
None
,
ratio
=
0.5
,
random_start
=
True
):
# noqa
r
""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...
...
@@ -27,12 +27,14 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
ratio (Tensor, optional): Sampling ratio. (default: :obj:`0.5`)
ratio (float or Tensor, optional): Sampling ratio.
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
...
...
@@ -44,10 +46,7 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
"""
if
not
isinstance
(
ratio
,
Tensor
):
ratio
=
torch
.
tensor
(
ratio
)
assert
len
(
ratio
.
shape
)
<
2
,
f
'ratio should be a scalar or a vector, received a tensor rank
{
len
(
ratio
.
shape
)
}
'
ratio
=
ratio
.
to
(
src
.
device
)
ratio
=
torch
.
tensor
(
ratio
,
device
=
src
.
device
)
if
batch
is
not
None
:
assert
src
.
size
(
0
)
==
batch
.
numel
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment