Commit a5abfee8 authored by rusty1s's avatar rusty1s
Browse files

fix nearest parallel reduction

parent 817b767e
...@@ -18,7 +18,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, ...@@ -18,7 +18,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
for (int64_t b = 0; b < batch_size; b++) { for (int64_t b = 0; b < batch_size; b++) {
if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) { if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
batch_idx = b; batch_idx = b;
continue; break;
} }
} }
...@@ -47,12 +47,15 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, ...@@ -47,12 +47,15 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
best_dist[thread_idx] = best; best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx; best_dist_idx[thread_idx] = best_idx;
for (int64_t i = 1; i < THREADS; i *= 2) { for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads(); __syncthreads();
if ((thread_idx + i) < THREADS && if (thread_idx < (THREADS >> (u + 1))) {
best_dist[thread_idx] > best_dist[thread_idx + i]) { int64_t idx_1 = (thread_idx * 2) << u;
best_dist[thread_idx] = best_dist[thread_idx + i]; int64_t idx_2 = (thread_idx * 2 + 1) << u;
best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i]; if (best_dist[idx_1] > best_dist[idx_2]) {
best_dist[idx_1] = best_dist[idx_2];
best_dist_idx[idx_1] = best_dist_idx[idx_2];
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment