Commit 0d735d7e authored by rusty1s

improve radius performance

parent 0adaf7f9
@@ -4,4 +4,5 @@
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -27,33 +27,28 @@ template <typename scalar_t> struct Cosine {
}
};
__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
const int64_t num_examples) {
for (int64_t i = 0; i < num_examples; i++) {
if (ptr[i + 1] > idx)
return i;
}
return num_examples - 1;
}
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
scalar_t *__restrict__ dist, int64_t *__restrict__ row,
int64_t *__restrict__ col, const int64_t k, const int64_t n,
const int64_t m, const int64_t dim, const int64_t num_examples,
const bool cosine) {
int64_t *__restrict__ row, int64_t *__restrict__ col,
const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
const int64_t num_examples, const bool cosine) {
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
if (n_y >= m)
return;
for (int64_t e = 0; e < k; e++)
row[n_y * k + e] = n_y;
const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
scalar_t best_dist[100];
int64_t best_idx[100];
for (int e = 0; e < k; e++) {
best_dist[e] = 1e10;
best_idx[e] = -1;
}
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
scalar_t tmp_dist = 0;
@@ -70,17 +65,22 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
}
for (int64_t e1 = 0; e1 < k; e1++) {
if (dist[n_y * k + e1] > tmp_dist) {
if (best_dist[e1] > tmp_dist) {
for (int64_t e2 = k - 1; e2 > e1; e2--) {
dist[n_y * k + e2] = dist[n_y * k + e2 - 1];
col[n_y * k + e2] = col[n_y * k + e2 - 1];
best_dist[e2] = best_dist[e2 - 1];
best_idx[e2] = best_idx[e2 - 1];
}
dist[n_y * k + e1] = tmp_dist;
col[n_y * k + e1] = n_x;
best_dist[e1] = tmp_dist;
best_idx[e1] = n_x;
break;
}
}
}
for (int64_t e = 0; e < k; e++) {
row[n_y * k + e] = n_y;
col[n_y * k + e] = best_idx[e];
}
}
torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
@@ -89,10 +89,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
const bool cosine) {
CHECK_CUDA(x);
CHECK_CONTIGUOUS(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CUDA(y);
CHECK_CONTIGUOUS(y);
CHECK_INPUT(y.dim() == 2);
CHECK_INPUT(x.size(1) == y.size(1));
AT_ASSERTM(k <= 100, "`k` needs to be smaller than or equal to 100");
if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
@@ -112,7 +115,6 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
cudaSetDevice(x.get_device());
auto dist = torch::full(y.size(0) * k, 1e10, y.options());
auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
@@ -123,9 +125,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
ptr_x.value().numel() - 1, cosine);
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
});
auto mask = col != -1;
......
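For context on the kNN changes above: the kernel now keeps its k best candidates in small thread-local arrays (best_dist / best_idx, capped by the new k <= 100 check) and writes row/col to global memory only once per query point, instead of maintaining the running top-k in the global dist/col buffers. Below is a minimal host-side sketch of that insertion step, not part of the commit; insert_candidate, K and the sample values are illustrative only.

// Sketch only: maintain the K smallest distances seen so far in a fixed-size
// buffer, mirroring the thread-local top-k insertion in knn_kernel.
#include <cstdint>
#include <cstdio>

constexpr int64_t K = 3;

void insert_candidate(float best_dist[K], int64_t best_idx[K], float d,
                      int64_t idx) {
  for (int64_t e1 = 0; e1 < K; e1++) {
    if (best_dist[e1] > d) {                    // found the insertion slot
      for (int64_t e2 = K - 1; e2 > e1; e2--) { // shift worse entries down
        best_dist[e2] = best_dist[e2 - 1];
        best_idx[e2] = best_idx[e2 - 1];
      }
      best_dist[e1] = d;
      best_idx[e1] = idx;
      break;
    }
  }
}

int main() {
  float best_dist[K] = {1e10f, 1e10f, 1e10f}; // sentinel distances
  int64_t best_idx[K] = {-1, -1, -1};
  insert_candidate(best_dist, best_idx, 2.0f, 10);
  insert_candidate(best_dist, best_idx, 1.0f, 20);
  insert_candidate(best_dist, best_idx, 0.5f, 40);
  for (int64_t e = 0; e < K; e++)
    printf("%lld %f\n", (long long)best_idx[e], best_dist[e]); // 40, 20, 10
  return 0;
}

Keeping the running top-k in local memory avoids up to k global reads and writes per candidate, which is likely where most of the benefit of this part of the change comes from.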
@@ -4,84 +4,88 @@
#include "utils.cuh"
#define THREADS 1024
#define THREADS 256
template <typename scalar_t>
__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *row, int64_t *col, scalar_t radius,
int64_t max_num_neighbors, int64_t dim) {
__global__ void
radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ ptr_x,
const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
int64_t *__restrict__ col, const scalar_t r, const int64_t n,
const int64_t m, const int64_t dim, const int64_t num_examples,
const int64_t max_num_neighbors) {
const int64_t batch_idx = blockIdx.x;
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
if (n_y >= m)
return;
const int64_t x_start_idx = ptr_x[batch_idx];
const int64_t x_end_idx = ptr_x[batch_idx + 1];
const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1];
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
n_y += THREADS) {
int64_t count = 0;
for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
dist = sqrt(dist);
if (dist < radius) {
if (dist < r) {
row[n_y * max_num_neighbors + count] = n_y;
col[n_y * max_num_neighbors + count] = n_x;
count++;
}
if (count >= max_num_neighbors) {
if (count >= max_num_neighbors)
break;
}
}
}
}
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, double r,
int64_t max_num_neighbors) {
torch::optional<torch::Tensor> ptr_y, const double r,
const int64_t max_num_neighbors) {
CHECK_CUDA(x);
CHECK_CONTIGUOUS(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CUDA(y);
CHECK_CONTIGUOUS(y);
CHECK_INPUT(y.dim() == 2);
CHECK_INPUT(x.size(1) == y.size(1));
cudaSetDevice(x.get_device());
if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else {
} else
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
}
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else {
} else
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
}
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
cudaSetDevice(x.get_device());
auto row =
torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
auto col =
torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
x.size(1));
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0),
y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors);
});
auto mask = row != -1;
......
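For context on the radius changes above: the kernel is now launched with one thread per query point (BLOCKS derived from y.size(0)) rather than one block per example, looks up the owning example via get_example_idx, and receives r * r from the host so it can compare squared distances without a sqrt per candidate. The following is a minimal host-side sketch of that inner loop, not part of the commit; find_neighbors and the sample points are hypothetical.

// Sketch only: collect up to max_num_neighbors indices of x within radius r
// of a single query point, using a squared-distance comparison.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> find_neighbors(const std::vector<float> &x, int64_t dim,
                                    const float *y_pt, float r,
                                    int64_t max_num_neighbors) {
  std::vector<int64_t> out;
  const float r2 = r * r; // squared radius, so no sqrt inside the loop
  const int64_t n = (int64_t)x.size() / dim;
  for (int64_t n_x = 0; n_x < n; n_x++) {
    float dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      const float diff = x[n_x * dim + d] - y_pt[d];
      dist += diff * diff;
    }
    if (dist < r2)
      out.push_back(n_x);
    if ((int64_t)out.size() >= max_num_neighbors)
      break;
  }
  return out;
}

int main() {
  std::vector<float> x = {0.f, 0.f, 1.f, 0.f, 3.f, 0.f}; // three 2-D points
  const float y_pt[2] = {0.f, 0.f};
  for (int64_t idx : find_neighbors(x, /*dim=*/2, y_pt, /*r=*/1.5f, /*max=*/32))
    printf("%lld\n", (long long)idx); // prints 0 and 1
  return 0;
}

Parallelizing over query points also keeps the launch balanced when examples in the batch have very different sizes, whereas the previous one-block-per-example launch could leave most of a block idle for small examples.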
@@ -5,3 +5,14 @@
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
const int64_t num_examples) {
for (int64_t i = 0; i < num_examples; i++) {
if (ptr[i + 1] > idx)
return i;
}
return num_examples - 1;
}
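For reference, ptr here is a CSR-style batch pointer with num_examples + 1 entries: rows ptr[i] through ptr[i + 1] - 1 belong to example i, and get_example_idx linearly scans it to find the example that owns a given row. Below is a minimal host-side sketch of the same lookup, not part of the commit; example_idx_host and the sample pointer are illustrative only.

// Sketch only: map a row index to its example via a CSR-style batch pointer.
#include <cassert>
#include <cstdint>
#include <vector>

int64_t example_idx_host(int64_t idx, const std::vector<int64_t> &ptr) {
  const int64_t num_examples = (int64_t)ptr.size() - 1;
  for (int64_t i = 0; i < num_examples; i++) {
    if (ptr[i + 1] > idx) // first boundary past idx marks the owning example
      return i;
  }
  return num_examples - 1;
}

int main() {
  std::vector<int64_t> ptr = {0, 3, 7}; // example 0: rows 0-2, example 1: rows 3-6
  assert(example_idx_host(2, ptr) == 0);
  assert(example_idx_host(3, ptr) == 1);
  assert(example_idx_host(6, ptr) == 1);
  return 0;
}

Since ptr is short (one entry per example plus one), the linear scan is cheap; a binary search would also work but is unlikely to matter at typical batch sizes.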
@@ -71,7 +71,7 @@ def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device)
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
max_num_neighbors=2000, num_workers=6)
max_num_neighbors=2000)
tree = scipy.spatial.cKDTree(x.cpu().numpy())
col = tree.query_ball_point(x.cpu(), r=0.5)
......