Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
e83a45a3
Commit
e83a45a3
authored
Mar 12, 2020
by
rusty1s
Browse files
nearest fixes
parent
713fb60a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
8 deletions
+10
-8
csrc/cuda/nearest_cuda.cu
csrc/cuda/nearest_cuda.cu
+10
-8
No files found.
csrc/cuda/nearest_cuda.cu
View file @
e83a45a3
...
...
@@ -8,12 +8,14 @@
template
<
typename
scalar_t
>
__global__
void
nearest_kernel
(
const
scalar_t
*
x
,
const
scalar_t
*
y
,
const
int64_t
*
_
ptr_x
,
const
int64_t
*
_
ptr_y
,
const
int64_t
*
ptr_x
,
const
int64_t
*
ptr_y
,
int64_t
*
out
,
int64_t
batch_size
,
int64_t
dim
)
{
const
int64_t
thread_idx
=
threadIdx
.
x
;
const
int64_t
n_x
=
blockIdx
.
x
;
int64_t
batch_idx
;
for
(
int64_t
b
=
0
;
b
<
batch_idx
;
b
++
)
for
(
int64_t
b
=
0
;
b
<
ptr_x
.
size
(
0
)
-
1
;
b
++
)
if
(
ptr_x
[
b
]
>=
n_x
and
ptr_x
[
b
+
1
]
<
n_x
)
batch_idx
=
b
;
...
...
@@ -25,7 +27,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
scalar_t
best
=
1e38
;
int64_t
best_idx
=
0
;
for
(
int64_t
n_y
=
y_start_idx
+
thread
Idx
.
x
;
n_y
<
y_end_idx
;
for
(
int64_t
n_y
=
y_start_idx
+
thread
_id
x
;
n_y
<
y_end_idx
;
n_y
+=
THREADS
)
{
scalar_t
dist
=
0
;
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
...
...
@@ -39,14 +41,14 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
}
}
best_dist
[
idx
]
=
best
;
best_dist_idx
[
idx
]
=
best_idx
;
best_dist
[
thread_
idx
]
=
best
;
best_dist_idx
[
thread_
idx
]
=
best_idx
;
for
(
int64_t
u
=
0
;
(
1
<<
u
)
<
THREADS
;
u
++
)
{
__syncthreads
();
if
(
idx
<
(
THREADS
>>
(
u
+
1
)))
{
int64_t
idx_1
=
(
idx
*
2
)
<<
u
;
int64_t
idx_2
=
(
idx
*
2
+
1
)
<<
u
;
int64_t
idx_1
=
(
thread_
idx
*
2
)
<<
u
;
int64_t
idx_2
=
(
thread_
idx
*
2
+
1
)
<<
u
;
if
(
best_dist
[
idx_1
]
>
best_dist
[
idx_2
])
{
best_dist
[
idx_1
]
=
best_dist
[
idx_2
];
best_dist_idx
[
idx_1
]
=
best_dist_idx
[
idx_2
];
...
...
@@ -55,7 +57,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
}
__syncthreads
();
if
(
idx
==
0
)
{
if
(
thread_
idx
==
0
)
{
out
[
n_x
]
=
best_dist_idx
[
0
];
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment