Commit e83a45a3 authored by rusty1s's avatar rusty1s
Browse files

nearest fixes

parent 713fb60a
...@@ -8,12 +8,14 @@ ...@@ -8,12 +8,14 @@
template <typename scalar_t> template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *_ptr_x, const int64_t *_ptr_y, const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *out, int64_t batch_size, int64_t dim) { int64_t *out, int64_t batch_size, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t n_x = blockIdx.x; const int64_t n_x = blockIdx.x;
int64_t batch_idx; int64_t batch_idx;
for (int64_t b = 0; b < batch_idx; b++) for (int64_t b = 0; b < ptr_x.size(0) - 1; b++)
if (ptr_x[b] >= n_x and ptr_x[b + 1] < n_x) if (ptr_x[b] >= n_x and ptr_x[b + 1] < n_x)
batch_idx = b; batch_idx = b;
...@@ -25,7 +27,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, ...@@ -25,7 +27,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
scalar_t best = 1e38; scalar_t best = 1e38;
int64_t best_idx = 0; int64_t best_idx = 0;
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx; for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
n_y += THREADS) { n_y += THREADS) {
scalar_t dist = 0; scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) { for (int64_t d = 0; d < dim; d++) {
...@@ -39,14 +41,14 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, ...@@ -39,14 +41,14 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
} }
} }
best_dist[idx] = best; best_dist[thread_idx] = best;
best_dist_idx[idx] = best_idx; best_dist_idx[thread_idx] = best_idx;
for (int64_t u = 0; (1 << u) < THREADS; u++) { for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads(); __syncthreads();
if (idx < (THREADS >> (u + 1))) { if (idx < (THREADS >> (u + 1))) {
int64_t idx_1 = (idx * 2) << u; int64_t idx_1 = (thread_idx * 2) << u;
int64_t idx_2 = (idx * 2 + 1) << u; int64_t idx_2 = (thread_idx * 2 + 1) << u;
if (best_dist[idx_1] > best_dist[idx_2]) { if (best_dist[idx_1] > best_dist[idx_2]) {
best_dist[idx_1] = best_dist[idx_2]; best_dist[idx_1] = best_dist[idx_2];
best_dist_idx[idx_1] = best_dist_idx[idx_2]; best_dist_idx[idx_1] = best_dist_idx[idx_2];
...@@ -55,7 +57,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, ...@@ -55,7 +57,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
} }
__syncthreads(); __syncthreads();
if (idx == 0) { if (thread_idx == 0) {
out[n_x] = best_dist_idx[0]; out[n_x] = best_dist_idx[0];
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment