#include <ATen/ATen.h>

#include "atomics.cuh"
#include "utils.cuh"

#define THREADS 1024

// Farthest point sampling: one thread block per example in the batch. Each
// iteration appends the point that is farthest (in squared Euclidean
// distance) from the set of already selected points of that example.
template <typename scalar_t>
__global__ void
fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
           int64_t *__restrict__ cum_k, int64_t *__restrict__ start,
           scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
           int64_t *__restrict__ out, size_t dim) {

  const size_t batch_idx = blockIdx.x;
  const size_t idx = threadIdx.x;

  // Range of point indices belonging to this example.
  const size_t start_idx = cum_deg[batch_idx];
  const size_t end_idx = cum_deg[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Seed the output with the (possibly random) start point of this example.
  if (idx == 0) {
    out[cum_k[batch_idx]] = start_idx + start[batch_idx];
  }

  for (ptrdiff_t m = cum_k[batch_idx] + 1; m < cum_k[batch_idx + 1]; m++) {
    ptrdiff_t best_idx = 0;
    scalar_t best = -1;

    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }
    __syncthreads();

    // Squared distance from every point to the most recently selected point,
    // accumulated one coordinate at a time.
    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
      scalar_t d = x[(out[m - 1] * dim) + (i % dim)] - x[i];
      atomicAdd(&tmp_dist[i / dim], d * d);
    }
    __syncthreads();

    // Keep the distance to the nearest selected point and track, per thread,
    // the point that maximizes it.
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > best) {
        best = dist[n];
        best_idx = n;
      }
    }

    best_dist[idx] = best;
    best_dist_idx[idx] = best_idx;

    // Block-wide argmax reduction over the per-thread candidates.
    for (int64_t u = 0; (1 << u) < THREADS; u++) {
      __syncthreads();
      if (idx < (THREADS >> (u + 1))) {
        int64_t idx_1 = (idx * 2) << u;
        int64_t idx_2 = (idx * 2 + 1) << u;
        if (best_dist[idx_1] < best_dist[idx_2]) {
          best_dist[idx_1] = best_dist[idx_2];
          best_dist_idx[idx_1] = best_dist_idx[idx_2];
        }
      }
    }
    __syncthreads();

    out[m] = best_dist_idx[0];
  }
}

at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  // Number of examples = largest batch index + 1 (copied back to the host).
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(batch_sizes, batch[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;
  free(batch_sizes);

  // Points per example and how many to sample from each (round(deg * ratio)).
  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

  // Per-example start index: either a random point or the first point.
  at::Tensor start;
  if (random) {
    start = at::rand(batch_size, x.options());
    start = (start * deg.toType(at::kFloat)).toType(at::kLong);
  } else {
    start = at::zeros(batch_size, k.options());
  }

  auto dist = at::full(x.size(0), 1e38, x.options());
  auto tmp_dist = at::empty(x.size(0), x.options());

  // Total number of sampled points across the whole batch.
  auto k_sum = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = at::empty(k_sum[0], k.options());
  free(k_sum);

  AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
    fps_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.data<scalar_t>(), cum_deg.data<int64_t>(), cum_k.data<int64_t>(),
        start.data<int64_t>(), dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
        out.data<int64_t>(), x.size(1));
  });

  return out;
}
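
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file): a minimal example of how
// fps_cuda might be driven from host-side ATen code, kept commented out so it
// is not compiled into the extension. The helper name `fps_cuda_example` and
// the concrete tensor sizes are assumptions made purely for illustration;
// `x` is expected to be an [N, dim] CUDA tensor and `batch` a sorted int64
// vector of length N mapping each point to its example.
//
// at::Tensor fps_cuda_example() {
//   // 8 points in R^3, all belonging to a single example (batch index 0).
//   auto x = at::rand({8, 3}, at::device(at::kCUDA).dtype(at::kFloat));
//   auto batch = at::zeros(8, at::device(at::kCUDA).dtype(at::kLong));
//   // ratio = 0.5 keeps round(8 * 0.5) = 4 points per example; random = true
//   // picks a random start point within each example.
//   return fps_cuda(x, batch, /*ratio=*/0.5, /*random=*/true);
// }
// ---------------------------------------------------------------------------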