Commit 7da5a7e6 authored by rusty1s

fps in feature space done

parent f4ad453a
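This commit generalizes the farthest point sampling (FPS) kernel from fixed 3D coordinates to an arbitrary feature dimension dim, and replaces the serial farthest-point search with a block-wide shared-memory reduction. For orientation, here is a minimal CPU sketch of the algorithm the kernel computes per batch segment; it is illustrative only, and fps_reference with its signature is made up for this note:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// CPU reference for farthest point sampling over a [n, dim] feature
// matrix: repeatedly pick the point whose distance to the nearest
// already-selected point is largest. Sketch only, not the committed code.
std::vector<int64_t> fps_reference(const float *x, int64_t n, int64_t dim,
                                   int64_t k, int64_t start) {
  std::vector<int64_t> out = {start};
  std::vector<float> dist(n, std::numeric_limits<float>::max());
  for (int64_t m = 1; m < k; m++) {
    int64_t last = out.back(), best_idx = 0;
    float best = -1.f;
    for (int64_t i = 0; i < n; i++) {
      // Squared distance of point i to the most recently selected point.
      float d = 0.f;
      for (int64_t f = 0; f < dim; f++) {
        float diff = x[last * dim + f] - x[i * dim + f];
        d += diff * diff;
      }
      // dist[i] tracks the distance to the nearest selected point so far;
      // the next sample is the point that maximizes that distance.
      dist[i] = std::min(dist[i], d);
      if (dist[i] > best) {
        best = dist[i];
        best_idx = i;
      }
    }
    out.push_back(best_idx);
  }
  return out;
}

The kernel below performs the same loop per batch, but spreads the n * dim distance computation across the thread block (accumulating per-point squared distances with atomicAdd) and finds the argmax with a shared-memory reduction.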
@@ -14,33 +14,57 @@ fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
   const size_t batch_idx = blockIdx.x;
   const size_t idx = threadIdx.x;
-  const size_t stride = blockDim.x; // == THREADS
   const size_t start_idx = cum_deg[batch_idx];
   const size_t end_idx = cum_deg[batch_idx + 1];

-  int64_t old = start_idx + start[batch_idx];
+  __shared__ scalar_t best_dist[THREADS];
+  __shared__ int64_t best_dist_idx[THREADS];

   if (idx == 0) {
-    out[cum_k[batch_idx]] = old;
+    out[cum_k[batch_idx]] = start_idx + start[batch_idx];
   }

   for (ptrdiff_t m = cum_k[batch_idx] + 1; m < cum_k[batch_idx + 1]; m++) {
+    ptrdiff_t best_idx = 0;
+    scalar_t best = -1;

-    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += stride) {
+    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
       tmp_dist[n] = 0;
     }
     __syncthreads();

-    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += stride) {
-      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
+    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
+      scalar_t d = x[(out[m - 1] * dim) + (i % dim)] - x[i];
       atomicAdd(&tmp_dist[i / dim], d * d);
     }
     __syncthreads();

-    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += stride) {
+    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
       dist[n] = min(dist[n], tmp_dist[n]);
+      if (dist[n] > best) {
+        best = dist[n];
+        best_idx = n;
+      }
     }

+    best_dist[idx] = best;
+    best_dist_idx[idx] = best_idx;
+
+    for (int64_t u = 0; (1 << u) < THREADS; u++) {
+      __syncthreads();
+      if (idx < (THREADS >> (u + 1))) {
+        int64_t idx_1 = (idx * 2) << u;
+        int64_t idx_2 = (idx * 2 + 1) << u;
+        if (best_dist[idx_1] < best_dist[idx_2]) {
+          best_dist[idx_1] = best_dist[idx_2];
+          best_dist_idx[idx_1] = best_dist_idx[idx_2];
+        }
+      }
+    }
+    __syncthreads();
+
+    out[m] = best_dist_idx[0];
   }
 }
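The added best_dist / best_dist_idx lines are a standard shared-memory tree reduction: each thread proposes its local farthest candidate, and log2(THREADS) halving rounds leave the block-wide winner in slot 0. The indexing requires THREADS to be a power of two. A self-contained sketch of the same pattern follows; the kernel name and launch line are illustrative, not part of the commit:

#include <cstdint>

#define THREADS 256 // must be a power of two for the halving indices

// Standalone sketch of the block-wide argmax reduction used above.
// Each thread first scans a strided slice of vals (assumed non-negative,
// e.g. squared distances) for its local maximum; the reduction then
// leaves the global winner's index in out[0].
__global__ void block_argmax(const float *vals, int64_t n, int64_t *out) {
  __shared__ float best[THREADS];
  __shared__ int64_t best_idx[THREADS];

  float b = -1;
  int64_t b_idx = 0;
  for (int64_t i = threadIdx.x; i < n; i += THREADS) {
    if (vals[i] > b) {
      b = vals[i];
      b_idx = i;
    }
  }
  best[threadIdx.x] = b;
  best_idx[threadIdx.x] = b_idx;

  // Round u compares winners that sit 2^u slots apart; after the loop,
  // best[0] / best_idx[0] hold the block-wide maximum and its index.
  for (int u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (threadIdx.x < (THREADS >> (u + 1))) {
      int i1 = (threadIdx.x * 2) << u;
      int i2 = (threadIdx.x * 2 + 1) << u;
      if (best[i1] < best[i2]) {
        best[i1] = best[i2];
        best_idx[i1] = best_idx[i2];
      }
    }
  }
  __syncthreads();

  if (threadIdx.x == 0)
    *out = best_idx[0];
}
// Launch as: block_argmax<<<1, THREADS>>>(d_vals, n, d_out);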
@@ -78,103 +102,5 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
         out.data<int64_t>(), x.size(1));
   });

-  return dist;
+  return out;
 }
-
-// at::Tensor ifp_cuda(at::Tensor x, at::Tensor batch, float ratio) {
-//   AT_DISPATCH_FLOATING_TYPES(x.type(), "ifp_kernel", [&] {
-//     ifp_kernel<scalar_t><<<BLOCKS(x.numel()), THREADS>>>(
-//         x.data<scalar_t>(), batch.data<int64_t>(), ratio, x.numel());
-//   });
-//   return x;
-// }
-
-// __global__ void ifps_kernel() {}
-
-// // x: [N, F]
-// // count: [B]
-// // batch: [N]
-// // tmp min distances: [N]
-// // start node idx
-// // we parallelize over n times f
-// // parallelization over n times f: We can compute distances over atomicAdd
-// // each block corresponds to a batch
-// __global__ void farthestpointsamplingKernel(int b, int n, int m,
-//                                             const float *__restrict__ dataset,
-//                                             float *__restrict__ temp,
-//                                             int *__restrict__ idxs) {
-//   // dataset: [N*3] entries
-//   // b: batch-size
-//   // n: number of nodes
-//   // m: number of sample points
-//   if (m <= 0)
-//     return;
-//   const int BlockSize = 512;
-//   __shared__ float dists[BlockSize];
-//   __shared__ int dists_i[BlockSize];
-//   const int BufferSize = 3072;
-//   __shared__ float buf[BufferSize * 3];
-//   for (int i = blockIdx.x; i < b; i += gridDim.x) { // iterate over all batches?
-//     int old = 0;
-//     if (threadIdx.x == 0)
-//       idxs[i * m + 0] = old;
-//     for (int j = threadIdx.x; j < n; j += blockDim.x) { // iterate over all n
-//       temp[blockIdx.x * n + j] = 1e38;
-//     }
-//     for (int j = threadIdx.x; j < min(BufferSize, n) * 3; j += blockDim.x) {
-//       buf[j] = dataset[i * n * 3 + j];
-//     }
-//     __syncthreads();
-//     for (int j = 1; j < m; j++) {
-//       int besti = 0;
-//       float best = -1;
-//       float x1 = dataset[i * n * 3 + old * 3 + 0];
-//       float y1 = dataset[i * n * 3 + old * 3 + 1];
-//       float z1 = dataset[i * n * 3 + old * 3 + 2];
-//       for (int k = threadIdx.x; k < n; k += blockDim.x) {
-//         float td = temp[blockIdx.x * n + k];
-//         float x2, y2, z2;
-//         if (k < BufferSize) {
-//           x2 = buf[k * 3 + 0];
-//           y2 = buf[k * 3 + 1];
-//           z2 = buf[k * 3 + 2];
-//         } else {
-//           x2 = dataset[i * n * 3 + k * 3 + 0];
-//           y2 = dataset[i * n * 3 + k * 3 + 1];
-//           z2 = dataset[i * n * 3 + k * 3 + 2];
-//         }
-//         float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
-//                   (z2 - z1) * (z2 - z1);
-//         float d2 = min(d, td);
-//         if (d2 != td)
-//           temp[blockIdx.x * n + k] = d2;
-//         if (d2 > best) {
-//           best = d2;
-//           besti = k;
-//         }
-//       }
-//       dists[threadIdx.x] = best;
-//       dists_i[threadIdx.x] = besti;
-//       for (int u = 0; (1 << u) < blockDim.x; u++) {
-//         __syncthreads();
-//         if (threadIdx.x < (blockDim.x >> (u + 1))) {
-//           int i1 = (threadIdx.x * 2) << u;
-//           int i2 = (threadIdx.x * 2 + 1) << u;
-//           if (dists[i1] < dists[i2]) {
-//             dists[i1] = dists[i2];
-//             dists_i[i1] = dists_i[i2];
-//           }
-//         }
-//       }
-//       __syncthreads();
-//       old = dists_i[0];
-//       if (threadIdx.x == 0)
-//         idxs[i * m + j] = old;
-//     }
-//   }
-// }
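For context on the kernel's inputs: cum_deg and cum_k are CSR-style prefix sums over nodes and sample counts per example, so block b reads x[cum_deg[b]..cum_deg[b+1]) and writes out[cum_k[b]..cum_k[b+1]). A hypothetical host-side setup might look like the following; this is a sketch under that assumption, and none of these lines are the repository's actual code:

// Hypothetical host-side preparation (sketch, not the repository's code):
// cum_deg[b] is the exclusive prefix sum of nodes per example; cum_k[b]
// is the prefix sum of per-example sample counts k[b] = ceil(ratio * deg[b]).
int64_t num_examples = batch[-1].item<int64_t>() + 1; // batch is sorted
auto deg = at::zeros(num_examples, batch.options())
               .scatter_add_(0, batch, at::ones_like(batch));
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)});
auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)});
auto out = at::empty(cum_k[-1].item<int64_t>(), k.options());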