Commit 0d735d7e authored by rusty1s

improve radius performance

parent 0adaf7f9
@@ -4,4 +4,5 @@
 #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CONTIGUOUS(x) \
+  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -27,33 +27,28 @@ template <typename scalar_t> struct Cosine {
   }
 };
 
-__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
-                                   const int64_t num_examples) {
-  for (int64_t i = 0; i < num_examples; i++) {
-    if (ptr[i + 1] > idx)
-      return i;
-  }
-  return num_examples - 1;
-}
-
 template <typename scalar_t>
 __global__ void
 knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
            const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
-           scalar_t *__restrict__ dist, int64_t *__restrict__ row,
-           int64_t *__restrict__ col, const int64_t k, const int64_t n,
-           const int64_t m, const int64_t dim, const int64_t num_examples,
-           const bool cosine) {
+           int64_t *__restrict__ row, int64_t *__restrict__ col,
+           const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
+           const int64_t num_examples, const bool cosine) {
 
   const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
   if (n_y >= m)
     return;
 
-  for (int64_t e = 0; e < k; e++)
-    row[n_y * k + e] = n_y;
-
   const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
 
+  scalar_t best_dist[100];
+  int64_t best_idx[100];
+  for (int e = 0; e < k; e++) {
+    best_dist[e] = 1e10;
+    best_idx[e] = -1;
+  }
+
   for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
     scalar_t tmp_dist = 0;
@@ -70,17 +65,22 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
     }
 
     for (int64_t e1 = 0; e1 < k; e1++) {
-      if (dist[n_y * k + e1] > tmp_dist) {
+      if (best_dist[e1] > tmp_dist) {
         for (int64_t e2 = k - 1; e2 > e1; e2--) {
-          dist[n_y * k + e2] = dist[n_y * k + e2 - 1];
-          col[n_y * k + e2] = col[n_y * k + e2 - 1];
+          best_dist[e2] = best_dist[e2 - 1];
+          best_idx[e2] = best_idx[e2 - 1];
         }
-        dist[n_y * k + e1] = tmp_dist;
-        col[n_y * k + e1] = n_x;
+        best_dist[e1] = tmp_dist;
+        best_idx[e1] = n_x;
         break;
       }
     }
   }
+
+  for (int64_t e = 0; e < k; e++) {
+    row[n_y * k + e] = n_y;
+    col[n_y * k + e] = best_idx[e];
+  }
 }
 
 torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
@@ -89,10 +89,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
                        const bool cosine) {
   CHECK_CUDA(x);
+  CHECK_CONTIGUOUS(x);
   CHECK_INPUT(x.dim() == 2);
   CHECK_CUDA(y);
+  CHECK_CONTIGUOUS(y);
   CHECK_INPUT(y.dim() == 2);
   CHECK_INPUT(x.size(1) == y.size(1));
+  AT_ASSERTM(k <= 100, "`k` needs to smaller than or equal to 100");
 
   if (ptr_x.has_value()) {
     CHECK_CUDA(ptr_x.value());
@@ -112,7 +115,6 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
   cudaSetDevice(x.get_device());
 
-  auto dist = torch::full(y.size(0) * k, 1e10, y.options());
   auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
   auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
@@ -123,9 +125,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
     knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
-        col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
-        ptr_x.value().numel() - 1, cosine);
+        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
+        y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
   });
 
   auto mask = col != -1;
...
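Note: the main change to knn_kernel above is that each thread now keeps its running k-nearest list in the local best_dist / best_idx arrays (capped at 100 entries, hence the new `k <= 100` assertion) and writes row / col to global memory once at the end, instead of maintaining the sorted list in the removed global `dist` buffer. Below is a standalone, host-side sketch of that bounded insertion-sort update, not part of the commit; the helper name insert_top_k and the sample values are illustrative only.

// Host-side sketch of the per-thread top-k update used by knn_kernel.
#include <cstdint>
#include <cstdio>

constexpr int64_t K_MAX = 100; // mirrors the `k <= 100` assertion

// Insert candidate (dist, idx) into the sorted top-k buffers, smallest first.
void insert_top_k(double *best_dist, int64_t *best_idx, int64_t k, double dist,
                  int64_t idx) {
  for (int64_t e1 = 0; e1 < k; e1++) {
    if (best_dist[e1] > dist) {
      for (int64_t e2 = k - 1; e2 > e1; e2--) { // shift worse entries down
        best_dist[e2] = best_dist[e2 - 1];
        best_idx[e2] = best_idx[e2 - 1];
      }
      best_dist[e1] = dist;
      best_idx[e1] = idx;
      break;
    }
  }
}

int main() {
  const int64_t k = 3;
  double best_dist[K_MAX];
  int64_t best_idx[K_MAX];
  for (int64_t e = 0; e < k; e++) { // same initialization as the kernel
    best_dist[e] = 1e10;
    best_idx[e] = -1;
  }
  const double dists[] = {4.0, 1.0, 9.0, 0.5, 2.5};
  for (int64_t i = 0; i < 5; i++)
    insert_top_k(best_dist, best_idx, k, dists[i], i);
  for (int64_t e = 0; e < k; e++)
    printf("%ld: idx=%ld dist=%.1f\n", (long)e, (long)best_idx[e], best_dist[e]);
  // prints indices 3, 1, 4 (distances 0.5, 1.0, 2.5)
  return 0;
}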
@@ -4,84 +4,88 @@
 #include "utils.cuh"
 
-#define THREADS 1024
+#define THREADS 256
 
 template <typename scalar_t>
-__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
-                              const int64_t *ptr_x, const int64_t *ptr_y,
-                              int64_t *row, int64_t *col, scalar_t radius,
-                              int64_t max_num_neighbors, int64_t dim) {
+__global__ void
+radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
+              const int64_t *__restrict__ ptr_x,
+              const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
+              int64_t *__restrict__ col, const scalar_t r, const int64_t n,
+              const int64_t m, const int64_t dim, const int64_t num_examples,
+              const int64_t max_num_neighbors) {
 
-  const int64_t batch_idx = blockIdx.x;
+  const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_y >= m)
+    return;
 
-  const int64_t x_start_idx = ptr_x[batch_idx];
-  const int64_t x_end_idx = ptr_x[batch_idx + 1];
-  const int64_t y_start_idx = ptr_y[batch_idx];
-  const int64_t y_end_idx = ptr_y[batch_idx + 1];
-
-  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
-       n_y += THREADS) {
-    int64_t count = 0;
+  int64_t count = 0;
+  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
 
-    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
-      scalar_t dist = 0;
-      for (int64_t d = 0; d < dim; d++) {
-        dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
-                (x[n_x * dim + d] - y[n_y * dim + d]);
-      }
-      dist = sqrt(dist);
+  for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
+    scalar_t dist = 0;
+    for (int64_t d = 0; d < dim; d++) {
+      dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
+              (x[n_x * dim + d] - y[n_y * dim + d]);
+    }
 
-      if (dist < radius) {
-        row[n_y * max_num_neighbors + count] = n_y;
-        col[n_y * max_num_neighbors + count] = n_x;
-        count++;
-      }
+    if (dist < r) {
+      row[n_y * max_num_neighbors + count] = n_y;
+      col[n_y * max_num_neighbors + count] = n_x;
+      count++;
+    }
 
-      if (count >= max_num_neighbors) {
-        break;
-      }
-    }
-  }
+    if (count >= max_num_neighbors)
+      break;
+  }
 }
 
-torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
+torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
                           torch::optional<torch::Tensor> ptr_x,
-                          torch::optional<torch::Tensor> ptr_y, double r,
-                          int64_t max_num_neighbors) {
+                          torch::optional<torch::Tensor> ptr_y, const double r,
+                          const int64_t max_num_neighbors) {
   CHECK_CUDA(x);
+  CHECK_CONTIGUOUS(x);
   CHECK_INPUT(x.dim() == 2);
   CHECK_CUDA(y);
+  CHECK_CONTIGUOUS(y);
   CHECK_INPUT(y.dim() == 2);
+  CHECK_INPUT(x.size(1) == y.size(1));
   cudaSetDevice(x.get_device());
 
   if (ptr_x.has_value()) {
     CHECK_CUDA(ptr_x.value());
     CHECK_INPUT(ptr_x.value().dim() == 1);
-  } else {
+  } else
     ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
                           x.options().dtype(torch::kLong));
-  }
 
   if (ptr_y.has_value()) {
     CHECK_CUDA(ptr_y.value());
     CHECK_INPUT(ptr_y.value().dim() == 1);
-  } else {
+  } else
     ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
                           y.options().dtype(torch::kLong));
-  }
 
   CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
+  cudaSetDevice(x.get_device());
 
   auto row =
       torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
   auto col =
       torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
 
+  dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
+
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
-    radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
+    radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
-        x.size(1));
+        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0),
+        y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors);
  });
 
   auto mask = row != -1;
...
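Note: radius_kernel above now assigns one thread per query point (locating its batch segment via get_example_idx) instead of one block per example, lowers THREADS from 1024 to 256, and compares squared distances against r * r passed at launch, so the per-pair sqrt disappears. Below is a small standalone sketch of the two host-visible details, not part of the commit; the values are illustrative only.

// Host-side sketch: ceil-division grid sizing and the sqrt-free radius test.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <initializer_list>

int main() {
  // Grid size: enough THREADS-sized blocks to cover all m query points.
  const int64_t THREADS = 256;
  for (int64_t m : {1, 255, 256, 257, 1000})
    assert(((m + THREADS - 1) / THREADS) * THREADS >= m);

  // Comparing squared distances against r * r selects the same neighbors as
  // comparing sqrt'ed distances against r, since both sides are non-negative.
  const double r = 0.5;
  for (double d2 : {0.0, 0.24, 0.25, 0.26, 1.0})
    assert((std::sqrt(d2) < r) == (d2 < r * r));
  return 0;
}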
@@ -5,3 +5,14 @@
 #define CHECK_CUDA(x) \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+#define CHECK_CONTIGUOUS(x) \
+  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+
+__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
+                                   const int64_t num_examples) {
+  for (int64_t i = 0; i < num_examples; i++) {
+    if (ptr[i + 1] > idx)
+      return i;
+  }
+  return num_examples - 1;
+}
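Note: get_example_idx, now shared via utils.cuh, maps a flat point index to its example by scanning the CSR-style ptr offsets, which is what lets both kernels run one thread per point across the whole batch. Below is a host-side copy of the same loop with an illustrative ptr array, not part of the commit.

// Host-side sketch of get_example_idx on a toy CSR-style ptr array.
#include <cassert>
#include <cstdint>

int64_t get_example_idx_host(int64_t idx, const int64_t *ptr,
                             int64_t num_examples) {
  for (int64_t i = 0; i < num_examples; i++) {
    if (ptr[i + 1] > idx) // first segment whose end lies past idx
      return i;
  }
  return num_examples - 1;
}

int main() {
  // Three examples with 3, 2 and 4 points: indices 0-2, 3-4 and 5-8.
  const int64_t ptr[] = {0, 3, 5, 9};
  assert(get_example_idx_host(0, ptr, 3) == 0);
  assert(get_example_idx_host(4, ptr, 3) == 1);
  assert(get_example_idx_host(8, ptr, 3) == 2);
  return 0;
}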
@@ -71,7 +71,7 @@ def test_radius_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
 
     edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
-                              max_num_neighbors=2000, num_workers=6)
+                              max_num_neighbors=2000)
 
     tree = scipy.spatial.cKDTree(x.cpu().numpy())
     col = tree.query_ball_point(x.cpu(), r=0.5)
...