Commit a5abfee8 authored by rusty1s
Browse files

fix nearest parallel reduction

parent 817b767e
......@@ -18,7 +18,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
for (int64_t b = 0; b < batch_size; b++) {
if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
batch_idx = b;
continue;
break;
}
}
......@@ -47,12 +47,15 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t i = 1; i < THREADS; i *= 2) {
for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads();
if ((thread_idx + i) < THREADS &&
best_dist[thread_idx] > best_dist[thread_idx + i]) {
best_dist[thread_idx] = best_dist[thread_idx + i];
best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
if (thread_idx < (THREADS >> (u + 1))) {
int64_t idx_1 = (thread_idx * 2) << u;
int64_t idx_2 = (thread_idx * 2 + 1) << u;
if (best_dist[idx_1] > best_dist[idx_2]) {
best_dist[idx_1] = best_dist[idx_2];
best_dist_idx[idx_1] = best_dist_idx[idx_2];
}
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment