Commit 3df5f8f5 authored by rusty1s's avatar rusty1s
Browse files

ifps copy

parent 8bbe7950
__global__ void ifps_kernel() {}
// x: [N, F]
// count: [B]
// batch: [N]
// tmp min distances: [N]
// start node idx
// we parallelize over n times f
// parallelization over n times f: We can compute distances over atomicAdd
// each block corresponds to a batch
__global__ void farthestpointsamplingKernel(int b, int n, int m,
const float *__restrict__ dataset,
float *__restrict__ temp,
int *__restrict__ idxs) {
// dataset: [N*3] entries
// b: batch-size
// n: number of nodes
// m: number of sample points
if (m <= 0)
return;
const int BlockSize = 512;
__shared__ float dists[BlockSize];
__shared__ int dists_i[BlockSize];
const int BufferSize = 3072;
__shared__ float buf[BufferSize * 3];
for (int i = blockIdx.x; i < b; i += gridDim.x) { // iterate over all batches?
int old = 0;
if (threadIdx.x == 0)
idxs[i * m + 0] = old;
for (int j = threadIdx.x; j < n; j += blockDim.x) { // iterate over all n
temp[blockIdx.x * n + j] = 1e38;
}
for (int j = threadIdx.x; j < min(BufferSize, n) * 3; j += blockDim.x) {
buf[j] = dataset[i * n * 3 + j];
}
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[i * n * 3 + old * 3 + 0];
float y1 = dataset[i * n * 3 + old * 3 + 1];
float z1 = dataset[i * n * 3 + old * 3 + 2];
for (int k = threadIdx.x; k < n; k += blockDim.x) {
float td = temp[blockIdx.x * n + k];
float x2, y2, z2;
if (k < BufferSize) {
x2 = buf[k * 3 + 0];
y2 = buf[k * 3 + 1];
z2 = buf[k * 3 + 2];
} else {
x2 = dataset[i * n * 3 + k * 3 + 0];
y2 = dataset[i * n * 3 + k * 3 + 1];
z2 = dataset[i * n * 3 + k * 3 + 2];
}
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
(z2 - z1) * (z2 - z1);
float d2 = min(d, td);
if (d2 != td)
temp[blockIdx.x * n + k] = d2;
if (d2 > best) {
best = d2;
besti = k;
}
}
dists[threadIdx.x] = best;
dists_i[threadIdx.x] = besti;
for (int u = 0; (1 << u) < blockDim.x; u++) {
__syncthreads();
if (threadIdx.x < (blockDim.x >> (u + 1))) {
int i1 = (threadIdx.x * 2) << u;
int i2 = (threadIdx.x * 2 + 1) << u;
if (dists[i1] < dists[i2]) {
dists[i1] = dists[i2];
dists_i[i1] = dists_i[i2];
}
}
}
__syncthreads();
old = dists_i[0];
if (threadIdx.x == 0)
idxs[i * m + j] = old;
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment