#include "nearest_cuda.h" #include #include "utils.cuh" #define THREADS 1024 template __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, const int64_t *ptr_x, const int64_t *ptr_y, int64_t *out, int64_t batch_size, int64_t dim) { const int64_t thread_idx = threadIdx.x; const int64_t n_x = blockIdx.x; int64_t batch_idx; for (int64_t b = 0; b < batch_size; b++) { if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) { batch_idx = b; continue; } } const int64_t y_start_idx = ptr_y[batch_idx]; const int64_t y_end_idx = ptr_y[batch_idx + 1]; __shared__ scalar_t best_dist[THREADS]; __shared__ int64_t best_dist_idx[THREADS]; scalar_t best = 1e38; int64_t best_idx = 0; for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx; n_y += THREADS) { scalar_t dist = 0; for (int64_t d = 0; d < dim; d++) { dist += (x[n_x * dim + d] - y[n_y * dim + d]) * (x[n_x * dim + d] - y[n_y * dim + d]); } if (dist < best) { best = dist; best_idx = n_y; } } best_dist[thread_idx] = best; best_dist_idx[thread_idx] = best_idx; for (int64_t i = 1; i < THREADS; i *= 2) { __syncthreads(); if ((thread_idx + i) < THREADS && best_dist[thread_idx] > best_dist[thread_idx + i]) { best_dist[thread_idx] = best_dist[thread_idx + i]; best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i]; } } __syncthreads(); if (thread_idx == 0) { out[n_x] = best_dist_idx[0]; } } torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, torch::Tensor ptr_y) { CHECK_CUDA(x); CHECK_CUDA(y); CHECK_CUDA(ptr_x); CHECK_CUDA(ptr_y); cudaSetDevice(x.get_device()); x = x.view({x.size(0), -1}).contiguous(); y = y.view({y.size(0), -1}).contiguous(); auto out = torch::empty({x.size(0)}, ptr_x.options()); auto stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] { nearest_kernel<<>>( x.data_ptr(), y.data_ptr(), ptr_x.data_ptr(), ptr_y.data_ptr(), out.data_ptr(), ptr_x.size(0) - 1, x.size(1)); }); return out; }