Commit e83a45a3 authored by rusty1s
Browse files

nearest fixes

parent 713fb60a
......@@ -8,12 +8,14 @@
template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *_ptr_x, const int64_t *_ptr_y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *out, int64_t batch_size, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t n_x = blockIdx.x;
int64_t batch_idx;
for (int64_t b = 0; b < batch_idx; b++)
for (int64_t b = 0; b < ptr_x.size(0) - 1; b++)
if (ptr_x[b] >= n_x and ptr_x[b + 1] < n_x)
batch_idx = b;
......@@ -25,7 +27,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
scalar_t best = 1e38;
int64_t best_idx = 0;
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
n_y += THREADS) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
......@@ -39,14 +41,14 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
}
}
best_dist[idx] = best;
best_dist_idx[idx] = best_idx;
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads();
if (idx < (THREADS >> (u + 1))) {
int64_t idx_1 = (idx * 2) << u;
int64_t idx_2 = (idx * 2 + 1) << u;
int64_t idx_1 = (thread_idx * 2) << u;
int64_t idx_2 = (thread_idx * 2 + 1) << u;
if (best_dist[idx_1] > best_dist[idx_2]) {
best_dist[idx_1] = best_dist[idx_2];
best_dist_idx[idx_1] = best_dist_idx[idx_2];
......@@ -55,7 +57,7 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
}
__syncthreads();
if (idx == 0) {
if (thread_idx == 0) {
out[n_x] = best_dist_idx[0];
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment