Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
f7af865f
Commit
f7af865f
authored
Nov 15, 2018
by
rusty1s
Browse files
typos
parent
5d2168d2
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
29 additions
and
35 deletions
+29
-35
cuda/fps_kernel.cu
cuda/fps_kernel.cu
+6
-6
cuda/nearest.cpp
cuda/nearest.cpp
+4
-5
cuda/nearest_kernel.cu
cuda/nearest_kernel.cu
+12
-16
test/test_nearest.py
test/test_nearest.py
+2
-3
torch_cluster/fps.py
torch_cluster/fps.py
+3
-3
torch_cluster/nearest.py
torch_cluster/nearest.py
+2
-2
No files found.
cuda/fps_kernel.cu
View file @
f7af865f
...
...
@@ -11,7 +11,7 @@ template <typename scalar_t> struct Dist<scalar_t, 1> {
static
__device__
void
compute
(
ptrdiff_t
idx
,
ptrdiff_t
start_idx
,
ptrdiff_t
end_idx
,
ptrdiff_t
old
,
scalar_t
*
__restrict__
best
,
ptrdiff_t
*
__restrict__
best_idx
,
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
const
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
scalar_t
*
__restrict__
tmp_dist
,
size_t
dim
)
{
for
(
ptrdiff_t
n
=
start_idx
+
idx
;
n
<
end_idx
;
n
+=
THREADS
)
{
...
...
@@ -29,7 +29,7 @@ template <typename scalar_t> struct Dist<scalar_t, 2> {
static
__device__
void
compute
(
ptrdiff_t
idx
,
ptrdiff_t
start_idx
,
ptrdiff_t
end_idx
,
ptrdiff_t
old
,
scalar_t
*
__restrict__
best
,
ptrdiff_t
*
__restrict__
best_idx
,
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
const
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
scalar_t
*
__restrict__
tmp_dist
,
size_t
dim
)
{
for
(
ptrdiff_t
n
=
start_idx
+
idx
;
n
<
end_idx
;
n
+=
THREADS
)
{
...
...
@@ -49,7 +49,7 @@ template <typename scalar_t> struct Dist<scalar_t, 3> {
static
__device__
void
compute
(
ptrdiff_t
idx
,
ptrdiff_t
start_idx
,
ptrdiff_t
end_idx
,
ptrdiff_t
old
,
scalar_t
*
__restrict__
best
,
ptrdiff_t
*
__restrict__
best_idx
,
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
const
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
scalar_t
*
__restrict__
tmp_dist
,
size_t
dim
)
{
for
(
ptrdiff_t
n
=
start_idx
+
idx
;
n
<
end_idx
;
n
+=
THREADS
)
{
...
...
@@ -70,7 +70,7 @@ template <typename scalar_t> struct Dist<scalar_t, -1> {
static
__device__
void
compute
(
ptrdiff_t
idx
,
ptrdiff_t
start_idx
,
ptrdiff_t
end_idx
,
ptrdiff_t
old
,
scalar_t
*
__restrict__
best
,
ptrdiff_t
*
__restrict__
best_idx
,
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
const
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
dist
,
scalar_t
*
__restrict__
tmp_dist
,
size_t
dim
)
{
for
(
ptrdiff_t
n
=
start_idx
+
idx
;
n
<
end_idx
;
n
+=
THREADS
)
{
...
...
@@ -96,8 +96,8 @@ template <typename scalar_t> struct Dist<scalar_t, -1> {
template
<
typename
scalar_t
,
int64_t
Dim
>
__global__
void
fps_kernel
(
scalar_t
*
__restrict__
x
,
int64_t
*
__restrict__
cum_deg
,
int64_t
*
__restrict__
cum_k
,
int64_t
*
__restrict__
start
,
fps_kernel
(
const
scalar_t
*
__restrict__
x
,
const
int64_t
*
__restrict__
cum_deg
,
const
int64_t
*
__restrict__
cum_k
,
const
int64_t
*
__restrict__
start
,
scalar_t
*
__restrict__
dist
,
scalar_t
*
__restrict__
tmp_dist
,
int64_t
*
__restrict__
out
,
size_t
dim
)
{
...
...
cuda/nearest.cpp
View file @
f7af865f
...
...
@@ -3,12 +3,11 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
nearest_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
nearest_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
);
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
nearest
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
)
{
at
::
Tensor
nearest
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
)
{
CHECK_CUDA
(
x
);
IS_CONTIGUOUS
(
x
);
CHECK_CUDA
(
y
);
...
...
cuda/nearest_kernel.cu
View file @
f7af865f
...
...
@@ -5,11 +5,11 @@
#define THREADS 1024
template
<
typename
scalar_t
>
__global__
void
nearest_kernel
(
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
y
,
int64_t
*
__restrict__
batch_x
,
int64_t
*
__restrict__
batch_
y
,
scalar_t
*
__restrict__
out
,
int64_t
*
__restrict__
out_idx
,
size_t
dim
)
{
__global__
void
nearest_kernel
(
const
scalar_t
*
__restrict__
x
,
const
scalar_t
*
__restrict__
y
,
const
int64_t
*
__restrict__
batch_
x
,
const
int64_t
*
__restrict__
batch_y
,
int64_t
*
__restrict__
out
,
const
size_t
dim
)
{
const
ptrdiff_t
n_x
=
blockIdx
.
x
;
const
ptrdiff_t
batch_idx
=
batch_x
[
n_x
];
...
...
@@ -55,13 +55,11 @@ nearest_kernel(scalar_t *__restrict__ x, scalar_t *__restrict__ y,
__syncthreads
();
if
(
idx
==
0
)
{
out
[
n_x
]
=
best_dist
[
0
];
out_idx
[
n_x
]
=
best_dist_idx
[
0
];
out
[
n_x
]
=
best_dist_idx
[
0
];
}
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
nearest_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
nearest_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
)
{
auto
batch_sizes
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
batch_sizes
,
batch_x
[
-
1
].
data
<
int64_t
>
(),
sizeof
(
int64_t
),
...
...
@@ -71,15 +69,13 @@ std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
batch_y
=
degree
(
batch_y
,
batch_size
);
batch_y
=
at
::
cat
({
at
::
zeros
(
1
,
batch_y
.
options
()),
batch_y
.
cumsum
(
0
)},
0
);
auto
out
=
at
::
empty
(
x
.
size
(
0
),
x
.
options
());
auto
out_idx
=
at
::
empty_like
(
batch_x
);
auto
out
=
at
::
empty_like
(
batch_x
);
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"
fps
_kernel"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"
nearest
_kernel"
,
[
&
]
{
nearest_kernel
<
scalar_t
><<<
x
.
size
(
0
),
THREADS
>>>
(
x
.
data
<
scalar_t
>
(),
y
.
data
<
scalar_t
>
(),
batch_x
.
data
<
int64_t
>
(),
batch_y
.
data
<
int64_t
>
(),
out
.
data
<
scalar_t
>
(),
out_idx
.
data
<
int64_t
>
(),
x
.
size
(
1
));
batch_y
.
data
<
int64_t
>
(),
out
.
data
<
int64_t
>
(),
x
.
size
(
1
));
});
return
std
::
make_tuple
(
out
,
out_idx
)
;
return
out
;
}
test/test_nearest.py
View file @
f7af865f
...
...
@@ -32,6 +32,5 @@ def test_nearest(dtype, device):
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
0
,
1
,
1
],
torch
.
long
,
device
)
dist
,
idx
=
nearest
(
x
,
y
,
batch_x
,
batch_y
)
assert
dist
.
tolist
()
==
[
1
,
1
,
1
,
1
,
2
,
2
,
2
,
2
]
assert
idx
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
out
=
nearest
(
x
,
y
,
batch_x
,
batch_y
)
assert
out
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
torch_cluster/fps.py
View file @
f7af865f
...
...
@@ -29,13 +29,13 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
if
batch
is
None
:
batch
=
x
.
new_zeros
(
x
.
size
(
0
),
dtype
=
torch
.
long
)
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
assert
x
.
is_cuda
assert
x
.
dim
()
<
=
2
and
batch
.
dim
()
==
1
assert
x
.
dim
()
=
=
2
and
batch
.
dim
()
==
1
assert
x
.
size
(
0
)
==
batch
.
size
(
0
)
assert
ratio
>
0
and
ratio
<
1
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
op
=
fps_cuda
.
fps
if
x
.
is_cuda
else
None
out
=
op
(
x
,
batch
,
ratio
,
random_start
)
...
...
torch_cluster/nearest.py
View file @
f7af865f
...
...
@@ -22,6 +22,6 @@ def nearest(x, y, batch_x=None, batch_y=None):
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
op
=
nearest_cuda
.
nearest
if
x
.
is_cuda
else
None
dist
,
idx
=
op
(
x
,
y
,
batch_x
,
batch_y
)
out
=
op
(
x
,
y
,
batch_x
,
batch_y
)
return
dist
,
idx
return
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment