Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
a5abfee8
Commit
a5abfee8
authored
Dec 07, 2020
by
rusty1s
Browse files
fix nearest parallel reduction
parent
817b767e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
csrc/cuda/nearest_cuda.cu
csrc/cuda/nearest_cuda.cu
+9
-6
No files found.
csrc/cuda/nearest_cuda.cu
View file @
a5abfee8
...
@@ -18,7 +18,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
...
@@ -18,7 +18,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
for
(
int64_t
b
=
0
;
b
<
batch_size
;
b
++
)
{
for
(
int64_t
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
n_x
>=
ptr_x
[
b
]
&&
n_x
<
ptr_x
[
b
+
1
])
{
if
(
n_x
>=
ptr_x
[
b
]
&&
n_x
<
ptr_x
[
b
+
1
])
{
batch_idx
=
b
;
batch_idx
=
b
;
continue
;
break
;
}
}
}
}
...
@@ -47,12 +47,15 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
...
@@ -47,12 +47,15 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
best_dist
[
thread_idx
]
=
best
;
best_dist
[
thread_idx
]
=
best
;
best_dist_idx
[
thread_idx
]
=
best_idx
;
best_dist_idx
[
thread_idx
]
=
best_idx
;
for
(
int64_t
i
=
1
;
i
<
THREADS
;
i
*=
2
)
{
for
(
int64_t
u
=
0
;
(
1
<<
u
)
<
THREADS
;
u
++
)
{
__syncthreads
();
__syncthreads
();
if
((
thread_idx
+
i
)
<
THREADS
&&
if
(
thread_idx
<
(
THREADS
>>
(
u
+
1
)))
{
best_dist
[
thread_idx
]
>
best_dist
[
thread_idx
+
i
])
{
int64_t
idx_1
=
(
thread_idx
*
2
)
<<
u
;
best_dist
[
thread_idx
]
=
best_dist
[
thread_idx
+
i
];
int64_t
idx_2
=
(
thread_idx
*
2
+
1
)
<<
u
;
best_dist_idx
[
thread_idx
]
=
best_dist_idx
[
thread_idx
+
i
];
if
(
best_dist
[
idx_1
]
>
best_dist
[
idx_2
])
{
best_dist
[
idx_1
]
=
best_dist
[
idx_2
];
best_dist_idx
[
idx_1
]
=
best_dist_idx
[
idx_2
];
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment