Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
f8c1aefa
Commit
f8c1aefa
authored
Nov 14, 2018
by
rusty1s
Browse files
nearest neighbor implementation
parent
080ad88b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
1 deletion
+73
-1
cuda/nearest_kernel.cu
cuda/nearest_kernel.cu
+73
-1
No files found.
cuda/nearest_kernel.cu
View file @
f8c1aefa
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include "utils.cuh"
#define THREADS 1024
#define THREADS 1024
template
<
typename
scalar_t
,
int64_t
Dim
>
__global__
void
nearest_kernel
(
scalar_t
*
__restrict__
x
,
scalar_t
*
__restrict__
y
,
int64_t
*
__restrict__
batch_x
,
int64_t
*
__restrict__
batch_y
,
int64_t
*
__restrict__
out
,
size_t
dim
)
{
const
ptrdiff_t
n_x
=
blockIdx
.
x
;
const
ptrdiff_t
batch_idx
=
batch_x
[
n_x
];
const
ptrdiff_t
idx
=
threadIdx
.
x
;
const
ptrdiff_t
start_idx
=
batch_y
[
batch_idx
];
const
ptrdiff_t
end_idx
=
batch_y
[
batch_idx
+
1
];
__shared__
scalar_t
best_dist
[
THREADS
];
__shared__
int64_t
best_dist_idx
[
THREADS
];
scalar_t
best
=
1e38
;
ptrdiff_t
best_idx
=
0
;
for
(
ptrdiff_t
n_y
=
start_idx
+
idx
;
n_y
<
end_idx
;
n_y
+=
THREADS
)
{
scalar_t
dist
=
0
;
for
(
ptrdiff_t
d
=
0
;
d
<
dim
;
d
++
)
{
dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
}
if
(
dist
<
best
)
{
best
=
dist
;
best_idx
=
n_y
;
}
}
best_dist
[
idx
]
=
best
;
best_dist_idx
[
idx
]
=
best_idx
;
for
(
int64_t
u
=
0
;
(
1
<<
u
)
<
THREADS
;
u
++
)
{
__syncthreads
();
if
(
idx
<
(
THREADS
>>
(
u
+
1
)))
{
int64_t
idx_1
=
(
idx
*
2
)
<<
u
;
int64_t
idx_2
=
(
idx
*
2
+
1
)
<<
u
;
if
(
best_dist
[
idx_1
]
>
best_dist
[
idx_2
])
{
best_dist
[
idx_1
]
=
best_dist
[
idx_2
];
best_dist_idx
[
idx_1
]
=
best_dist_idx
[
idx_2
];
}
}
}
__syncthreads
();
if
(
idx
==
0
)
{
out
[
n_x
]
=
best_dist_idx
[
0
];
}
}
at
::
Tensor
nearest_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
nearest_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
)
{
at
::
Tensor
batch_y
)
{
return
batch_x
;
auto
batch_sizes
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
batch_sizes
,
batch_x
[
-
1
].
data
<
int64_t
>
(),
sizeof
(
int64_t
),
cudaMemcpyDeviceToHost
);
auto
batch_size
=
batch_sizes
[
0
]
+
1
;
batch_y
=
degree
(
batch_y
,
batch_size
);
batch_y
=
at
::
cat
({
at
::
zeros
(
1
,
batch_y
.
options
()),
batch_y
.
cumsum
(
0
)},
0
);
auto
out
=
at
::
empty_like
(
batch_x
);
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"fps_kernel"
,
[
&
]
{
nearest_kernel
<
scalar_t
,
-
1
><<<
x
.
size
(
0
),
THREADS
>>>
(
x
.
data
<
scalar_t
>
(),
y
.
data
<
scalar_t
>
(),
batch_x
.
data
<
int64_t
>
(),
batch_y
.
data
<
int64_t
>
(),
out
.
data
<
int64_t
>
(),
x
.
size
(
1
));
});
return
out
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment