Commit de431201 authored by rusty1s's avatar rusty1s
Browse files

fix nearest bug

parent e1e47f9b
...@@ -15,9 +15,12 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y, ...@@ -15,9 +15,12 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
const int64_t n_x = blockIdx.x; const int64_t n_x = blockIdx.x;
int64_t batch_idx; int64_t batch_idx;
for (int64_t b = 0; b < batch_size; b++) for (int64_t b = 0; b < batch_size; b++) {
if (ptr_x[b] >= n_x and ptr_x[b + 1] < n_x) if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
batch_idx = b; batch_idx = b;
continue;
}
}
const int64_t y_start_idx = ptr_y[batch_idx]; const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1]; const int64_t y_end_idx = ptr_y[batch_idx + 1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment