Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
82968b99
Commit
82968b99
authored
Sep 22, 2020
by
Duc Nguyen
Browse files
accepted different ratios for different point clouds
parent
2bf5e763
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
22 additions
and
16 deletions
+22
-16
csrc/cpu/fps_cpu.cpp
csrc/cpu/fps_cpu.cpp
+3
-3
csrc/cpu/fps_cpu.h
csrc/cpu/fps_cpu.h
+1
-1
csrc/cuda/fps_cuda.cu
csrc/cuda/fps_cuda.cu
+5
-4
csrc/cuda/fps_cuda.h
csrc/cuda/fps_cuda.h
+1
-1
csrc/fps.cpp
csrc/fps.cpp
+1
-1
setup.py
setup.py
+1
-1
test/test_fps.py
test/test_fps.py
+3
-3
torch_cluster/fps.py
torch_cluster/fps.py
+7
-2
No files found.
csrc/cpu/fps_cpu.cpp
View file @
82968b99
...
...
@@ -6,20 +6,20 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return
(
x
-
x
[
idx
]).
norm
(
2
,
1
);
}
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
)
{
CHECK_CPU
(
src
);
CHECK_CPU
(
ptr
);
CHECK_INPUT
(
ptr
.
dim
()
==
1
);
AT_ASSERTM
(
ratio
>
0
&&
ratio
<
1
,
"Invalid input"
);
//
AT_ASSERTM(
at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1)))
, "Invalid input");
src
=
src
.
view
({
src
.
size
(
0
),
-
1
}).
contiguous
();
ptr
=
ptr
.
contiguous
();
auto
batch_size
=
ptr
.
size
(
0
)
-
1
;
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
(
float
)
ratio
;
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
ratio
;
out_ptr
=
out_ptr
.
ceil
().
toType
(
torch
::
kLong
).
cumsum
(
0
);
auto
out
=
torch
::
empty
(
out_ptr
[
-
1
].
data_ptr
<
int64_t
>
()[
0
],
ptr
.
options
());
...
...
csrc/cpu/fps_cpu.h
View file @
82968b99
...
...
@@ -2,5 +2,5 @@
#include <torch/extension.h>
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
);
csrc/cuda/fps_cuda.cu
View file @
82968b99
...
...
@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#include <stdio.h>
#define THREADS 256
template
<
typename
scalar_t
>
...
...
@@ -64,21 +64,22 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}
}
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
)
{
CHECK_CUDA
(
src
);
CHECK_CUDA
(
ptr
);
CHECK_INPUT
(
ptr
.
dim
()
==
1
);
AT_ASSERTM
(
ratio
>
0
&&
ratio
<
1
,
"Invalid input"
);
//
AT_ASSERTM(
at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1)))
, "Invalid input");
cudaSetDevice
(
src
.
get_device
());
src
=
src
.
view
({
src
.
size
(
0
),
-
1
}).
contiguous
();
ptr
=
ptr
.
contiguous
();
ratio
=
ratio
.
contiguous
();
auto
batch_size
=
ptr
.
size
(
0
)
-
1
;
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
(
float
)
ratio
;
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
ratio
;
out_ptr
=
out_ptr
.
ceil
().
toType
(
torch
::
kLong
).
cumsum
(
0
);
out_ptr
=
torch
::
cat
({
torch
::
zeros
(
1
,
ptr
.
options
()),
out_ptr
},
0
);
...
...
csrc/cuda/fps_cuda.h
View file @
82968b99
...
...
@@ -2,5 +2,5 @@
#include <torch/extension.h>
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
);
csrc/fps.cpp
View file @
82968b99
...
...
@@ -11,7 +11,7 @@
PyMODINIT_FUNC
PyInit__fps
(
void
)
{
return
NULL
;
}
#endif
torch
::
Tensor
fps
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
torch
::
Tensor
fps
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
torch
::
Tensor
ratio
,
bool
random_start
)
{
if
(
src
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
...
...
setup.py
View file @
82968b99
...
...
@@ -84,7 +84,7 @@ setup(
ext_modules
=
get_extensions
()
if
not
BUILD_DOCS
else
[],
cmdclass
=
{
'build_ext'
:
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
,
use_ninja
=
False
)
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
)
},
packages
=
find_packages
(),
)
test/test_fps.py
View file @
82968b99
...
...
@@ -21,10 +21,10 @@ def test_fps(dtype, device):
],
dtype
,
device
)
batch
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
out
=
fps
(
x
,
batch
,
ratio
=
0.5
,
random_start
=
False
)
out
=
fps
(
x
,
batch
,
ratio
=
torch
.
tensor
(
0.5
)
,
random_start
=
False
)
assert
out
.
tolist
()
==
[
0
,
2
,
4
,
6
]
out
=
fps
(
x
,
ratio
=
0.5
,
random_start
=
False
)
out
=
fps
(
x
,
ratio
=
torch
.
tensor
(
0.5
)
,
random_start
=
False
)
assert
out
.
sort
()[
0
].
tolist
()
==
[
0
,
5
,
6
,
7
]
...
...
@@ -36,5 +36,5 @@ def test_random_fps(device):
batch_1
=
torch
.
zeros
(
N
,
dtype
=
torch
.
long
,
device
=
device
)
batch_2
=
torch
.
ones
(
N
,
dtype
=
torch
.
long
,
device
=
device
)
batch
=
torch
.
cat
([
batch_1
,
batch_2
])
idx
=
fps
(
pos
,
batch
,
ratio
=
0.5
)
idx
=
fps
(
pos
,
batch
,
ratio
=
torch
.
tensor
(
0.5
)
)
assert
idx
.
min
()
>=
0
and
idx
.
max
()
<
2
*
N
torch_cluster/fps.py
View file @
82968b99
...
...
@@ -5,7 +5,7 @@ import torch
@
torch
.
jit
.
script
def
fps
(
src
:
torch
.
Tensor
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
ratio
:
float
=
0.5
,
random_start
:
bool
=
True
)
->
torch
.
Tensor
:
ratio
:
torch
.
Tensor
=
torch
.
tensor
(
0.5
)
,
random_start
:
bool
=
True
)
->
torch
.
Tensor
:
r
""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...
...
@@ -17,7 +17,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
ratio (
float
, optional): Sampling ratio. (default: :obj:`0.5`)
ratio (
Tensor
, optional): Sampling ratio. (default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
...
...
@@ -33,6 +33,11 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
index = fps(src, batch, ratio=0.5)
"""
assert
len
(
ratio
.
shape
)
<
2
,
'Invalid ratio'
ratio
=
ratio
.
to
(
src
.
device
)
if
len
(
ratio
.
shape
)
==
1
:
assert
ratio
.
shape
[
0
]
==
int
(
batch
.
max
())
+
1
,
'Mismatched input and ratio numbers'
if
batch
is
not
None
:
assert
src
.
size
(
0
)
==
batch
.
numel
()
batch_size
=
int
(
batch
.
max
())
+
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment