Unverified Commit 3d47693b authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Op] Farthest Point Sampler in Cpp and CUDA (#1630)

* working framework without actual algorithm logic

* rename

* fix

* fps passes compilation

* correct algorithm

* add cuda implementation

* update random start

* before refactor

* pass compilation but cuda not working

* working

* code working, will add docstring

* add mxnet support

* update docstring

* update doc and test

* cpplint

* cpplint

* pylint

* temporary fix

* fix for win64

* fix unittest

* fix

* fix

* remove comment

* move to geometry package

* remove redundant include

* add docstrings and comments

* add proof

* add validity check
parent c8b18b79
......@@ -103,6 +103,8 @@ file(GLOB DGL_SRC
src/kernel/*.cc
src/kernel/cpu/*.cc
src/runtime/*.cc
src/geometry/*.cc
src/geometry/cpu/*.cc
)
file(GLOB_RECURSE DGL_SRC_1
......
......@@ -236,6 +236,7 @@ macro(dgl_config_cuda out_variable)
src/kernel/cuda/*.cc
src/kernel/cuda/*.cu
src/runtime/cuda/*.cc
src/geometry/cuda/*.cu
)
dgl_select_nvcc_arch_flags(NVCC_FLAGS_ARCH)
......
"""Package for geometry common components."""
import importlib
import sys
from ..backend import backend_name
def _load_backend(mod_name):
    """Re-export the contents of a backend-specific sub-module from this package.

    The sub-module ``.<mod_name>`` is imported relative to this package and
    each of its module-level names is bound on this package under the same
    name.

    Double-underscore module attributes (``__name__``, ``__path__``,
    ``__file__``, ``__spec__``, ``__doc__``, ...) are skipped: copying them
    would overwrite this package's own import metadata with the sub-module's.

    Parameters
    ----------
    mod_name : str
        Name of the sub-module to load, without the leading dot.
    """
    mod = importlib.import_module('.%s' % mod_name, __name__)
    thismod = sys.modules[__name__]
    for api, obj in mod.__dict__.items():
        # Keep this package's identity intact; only forward real API names.
        if api.startswith('__') and api.endswith('__'):
            continue
        setattr(thismod, api, obj)
# Populate this package with the contents of the sub-module named after the active backend.
_load_backend(backend_name)
"""Python interfaces to DGL farthest point sampler."""
from .._ffi.function import _init_api
from .. import backend as F
def farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result):
    """Run farthest point sampling through the DGL C API, writing into ``result``.

    Parameters
    ----------
    data : tensor
        Flattened batch of point clouds of shape (N, d), where N is the total
        number of points and d is the dimension of each point.
    batch_size : int
        Number of point clouds stored in ``data``. N must be divisible by
        ``batch_size``.
    sample_points : int
        Number of points to sample from each point cloud.
    dist : tensor
        Pre-allocated tensor of shape (N, ) used for the to-sample distances.
    start_idx : tensor of int
        Pre-allocated tensor of shape (batch_size, ) holding the starting
        sample of each batch.
    result : tensor of int
        Pre-allocated tensor of shape (sample_points * batch_size, ) that
        receives the sampled indices.

    Returns
    -------
    None
        Nothing is returned; ``result`` is overwritten with the sampled indices.
    """
    num_points = F.shape(data)[0]
    assert num_points >= sample_points * batch_size
    assert num_points % batch_size == 0

    # Convert in the same order the arguments are passed to the C API.
    data_nd, dist_nd, start_nd, result_nd = (
        F.zerocopy_to_dgl_ndarray(tensor)
        for tensor in (data, dist, start_idx, result))
    _CAPI_FarthestPointSampler(
        data_nd, batch_size, sample_points, dist_nd, start_nd, result_nd)
_init_api('dgl.geometry', __name__)
"""Package for mxnet-specific Geometry modules."""
from .fps import *
"""Farthest Point Sampler for mxnet Geometry package"""
#pylint: disable=no-member, invalid-name
from mxnet import nd
from mxnet.gluon import nn
import numpy as np
from ..capi import farthest_point_sampler
class FarthestPointSampler(nn.Block):
    """Farthest Point Sampler

    In each batch, the algorithm starts from a randomly chosen point.
    Then for each point, we maintain the minimum to-sample distance.
    Finally, we pick the point with the maximum such distance.
    This process is repeated ``npoints`` - 1 times.

    Parameters
    ----------
    npoints : int
        The number of points to sample in each batch.
    """
    def __init__(self, npoints):
        super(FarthestPointSampler, self).__init__()
        self.npoints = npoints

    def forward(self, pos):
        r"""Memory allocation and sampling

        Parameters
        ----------
        pos : tensor
            The positional tensor of shape (B, N, C)

        Returns
        -------
        tensor of shape (B, self.npoints)
            The sampled indices in each batch.
        """
        ctx = pos.context
        B, N, C = pos.shape
        pos = pos.reshape(-1, C)
        dist = nd.zeros((B * N), dtype=pos.dtype, ctx=ctx)
        # The upper bound of randint is exclusive: pass N (not N - 1) so that
        # every point, including the last one, can be the starting sample and
        # a single-point cloud (N == 1) is still a valid input.
        # np.int64 replaces the ``np.int`` alias, which newer NumPy releases
        # no longer provide.
        start_idx = nd.random.randint(0, N, (B, ), dtype=np.int64, ctx=ctx)
        result = nd.zeros((self.npoints * B), dtype=np.int64, ctx=ctx)
        farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
        return result.reshape(B, self.npoints)
"""Package for pytorch-specific Geometry modules."""
from .fps import *
"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name
import torch as th
from torch import nn
from ..capi import farthest_point_sampler
class FarthestPointSampler(nn.Module):
    """Farthest Point Sampler without the need to compute all pairs of distance.

    In each batch, the algorithm starts from a randomly chosen point.
    Then for each point, we maintain the minimum to-sample distance.
    Finally, we pick the point with the maximum such distance.
    This process is repeated ``npoints`` - 1 times.

    Parameters
    ----------
    npoints : int
        The number of points to sample in each batch.
    """
    def __init__(self, npoints):
        super(FarthestPointSampler, self).__init__()
        self.npoints = npoints

    def forward(self, pos):
        r"""Memory allocation and sampling

        Parameters
        ----------
        pos : tensor
            The positional tensor of shape (B, N, C)

        Returns
        -------
        tensor of shape (B, self.npoints)
            The sampled indices in each batch.
        """
        device = pos.device
        B, N, C = pos.shape
        pos = pos.reshape(-1, C)
        dist = th.zeros((B * N), dtype=pos.dtype, device=device)
        # The upper bound of th.randint is exclusive: pass N (not N - 1) so
        # that every point, including the last one, can be the starting sample
        # and a single-point cloud (N == 1) does not raise.
        start_idx = th.randint(0, N, (B, ), dtype=th.int, device=device)
        result = th.zeros((self.npoints * B), dtype=th.int, device=device)
        farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
        return result.reshape(B, self.npoints)
/*!
* Copyright (c) 2019 by Contributors
* \file geometry/cpu/geometry_op_impl.cc
* \brief Geometry operator CPU implementation
*/
#include <dgl/array.h>
#include <numeric>
#include <vector>
namespace dgl {
using runtime::NDArray;
namespace geometry {
namespace impl {
/*!
 * \brief Farthest Point Sampler without the need to compute all pairs of distance.
 *
 * The input array has shape (N, d), where N is the number of points, and d is the dimension.
 * It consists of a (flattened) batch of point clouds of equal size.
 *
 * In each batch, the algorithm starts with the sample index specified by ``start_idx``.
 * Then for each point, we maintain the minimum (squared) to-sample distance.
 * Finally, we pick the point with the maximum such distance; ties go to the smallest index.
 * This process is repeated ``sample_points`` - 1 times.
 *
 * \param array The (N, d) point coordinates.
 * \param batch_size The number of point clouds in ``array``.
 * \param sample_points The number of points to sample from each point cloud.
 * \param dist Work buffer for the to-sample distances. Only its first
 *        N / batch_size entries are used; they are reused for every batch.
 * \param start_idx The index, within its batch, of the first sample of each batch.
 * \param result Output buffer of ``batch_size * sample_points`` sampled indices,
 *        each relative to the start of its batch.
 */
template <DLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
    NDArray dist, IdArray start_idx, IdArray result) {
  // Nothing to sample. Returning early also avoids a division by zero below
  // and a write into an empty result buffer.
  if (batch_size <= 0 || sample_points <= 0) {
    return;
  }

  const FloatType* array_data = static_cast<FloatType*>(array->data);
  const int64_t point_in_batch = array->shape[0] / batch_size;
  const int64_t dim = array->shape[1];

  // distance
  FloatType* dist_data = static_cast<FloatType*>(dist->data);
  // first sample of each cloud in the batch
  const IdType* start_idx_data = static_cast<IdType*>(start_idx->data);
  // return value
  IdType* ret_data = static_cast<IdType*>(result->data);

  int64_t array_start = 0, ret_start = 0;
  // loop over the point clouds of this batch; 64-bit counters match the 64-bit bounds
  for (int64_t b = 0; b < batch_size; ++b) {
    // the caller-provided start sample
    int64_t sample_idx = static_cast<int64_t>(start_idx_data[b]);
    ret_data[ret_start] = static_cast<IdType>(sample_idx);

    // sample the rest `sample_points - 1` points
    for (int64_t i = 0; i < sample_points - 1; ++i) {
      // re-init distance and the argmax
      int64_t dist_argmax = 0;
      FloatType dist_max = -1;
      // offset of the last sampled point; invariant over the scan below
      const int64_t sample_offset = (array_start + sample_idx) * dim;

      // update the distance
      for (int64_t j = 0; j < point_in_batch; ++j) {
        const int64_t point_offset = (array_start + j) * dim;
        // squared Euclidean distance to the last sampled point
        FloatType one_dist = 0;
        for (int64_t d = 0; d < dim; ++d) {
          const FloatType tmp = array_data[point_offset + d] - array_data[sample_offset + d];
          one_dist += tmp * tmp;
        }

        // for each out-of-set point, keep its nearest to-the-set distance;
        // the first round (i == 0) initializes the buffer for this batch
        if (i == 0 || dist_data[j] > one_dist) {
          dist_data[j] = one_dist;
        }

        // look for the farthest sample
        if (dist_data[j] > dist_max) {
          dist_argmax = j;
          dist_max = dist_data[j];
        }
      }

      // sample the `dist_argmax`-th point
      sample_idx = dist_argmax;
      ret_data[ret_start + i + 1] = static_cast<IdType>(sample_idx);
    }

    array_start += point_in_batch;
    ret_start += sample_points;
  }
}
// Explicit instantiations of the CPU (kDLCPU) sampler for every combination of
// float/double coordinates and 32/64-bit integer indices.
template void FarthestPointSampler<kDLCPU, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
} // namespace impl
} // namespace geometry
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file geometry/cuda/geometry_op_impl.cu
* \brief Geometry operator CUDA implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h"
#include "../geometry_op.h"
#define THREADS 1024
namespace dgl {
namespace geometry {
namespace impl {
/*!
 * \brief Farthest Point Sampler without the need to compute all pairs of distance.
 *
 * The input array has shape (N, d), where N is the number of points, and d is the dimension.
 * It consists of a (flattened) batch of point clouds.
 *
 * Each thread block (blockIdx.x) processes one point cloud. In each batch, the
 * algorithm starts with the sample index specified by ``start_idx``.
 * Then for each point, we maintain the minimum (squared) to-sample distance.
 * Finally, we pick the point with the maximum such distance.
 * This process is repeated ``sample_points`` - 1 times.
 *
 * The block-wide argmax assumes the kernel is launched with exactly THREADS
 * threads per block and that THREADS is a power of two.
 */
template <typename FloatType, typename IdType>
__global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size,
                           const int64_t sample_points, const int64_t point_in_batch,
                           const int64_t dim, const IdType *start_idx,
                           FloatType *dist_data, IdType *ret_data) {
  const int64_t thread_idx = threadIdx.x;
  const int64_t batch_idx = blockIdx.x;

  const int64_t array_start = point_in_batch * batch_idx;
  const int64_t ret_start = sample_points * batch_idx;

  // per-thread partial maximum and its index, reduced across the block below
  __shared__ FloatType dist_max_ht[THREADS];
  __shared__ int64_t dist_argmax_ht[THREADS];

  // start with the caller-provided initial sample
  if (thread_idx == 0) {
    ret_data[ret_start] = (IdType)(start_idx[batch_idx]);
  }

  // sample the rest `sample_points - 1` points
  for (int64_t i = 0; i < sample_points - 1; i++) {
    // makes thread 0's previous write to ret_data visible to the whole block and
    // ensures the shared buffers are no longer being read before they are overwritten
    __syncthreads();

    // the last sampled point
    const int64_t sample_idx = (int64_t)(ret_data[ret_start + i]);

    FloatType dist_max = (FloatType)(-1.);
    int64_t dist_argmax = 0;

    // multi-thread distance calculation: thread t handles points t, t + THREADS, ...
    for (int64_t j = thread_idx; j < point_in_batch; j += THREADS) {
      FloatType one_dist = (FloatType)(0.);
      for (int64_t d = 0; d < dim; d++) {
        FloatType tmp = array_data[(array_start + j) * dim + d] -
            array_data[(array_start + sample_idx) * dim + d];
        one_dist += tmp * tmp;
      }

      // keep the nearest to-the-set distance; the first round initializes the buffer
      if (i == 0 || dist_data[array_start + j] > one_dist) {
        dist_data[array_start + j] = one_dist;
      }

      if (dist_data[array_start + j] > dist_max) {
        dist_argmax = j;
        dist_max = dist_data[array_start + j];
      }
    }

    dist_max_ht[thread_idx] = dist_max;
    dist_argmax_ht[thread_idx] = dist_argmax;

    /*
     * Parallel tree reduction of the (max, argmax) pairs.
     *
     * At every step only threads [0, stride) are active. They read slots
     * [stride, 2 * stride) and write slots [0, stride), so no slot is read and
     * written within the same step and a pair can never be observed half-updated.
     * After the last step (stride == 1) slot [0] holds the block-wide maximum.
     * On equal distances the smaller point index wins, which makes the result
     * independent of how points are distributed over threads.
     */
    for (int64_t stride = THREADS / 2; stride > 0; stride /= 2) {
      // executed by all threads of the block, active or not
      __syncthreads();
      if (thread_idx < stride) {
        const FloatType other_max = dist_max_ht[thread_idx + stride];
        const int64_t other_argmax = dist_argmax_ht[thread_idx + stride];
        if (dist_max_ht[thread_idx] < other_max ||
            (dist_max_ht[thread_idx] == other_max &&
             other_argmax < dist_argmax_ht[thread_idx])) {
          dist_max_ht[thread_idx] = other_max;
          dist_argmax_ht[thread_idx] = other_argmax;
        }
      }
    }

    // thread 0 wrote slot [0] itself in the last step, so no extra barrier is needed
    if (thread_idx == 0) {
      ret_data[ret_start + i + 1] = (IdType)(dist_argmax_ht[0]);
    }
  }
}
/*!
 * \brief Host-side launcher of the CUDA farthest point sampler.
 *
 * Launches one thread block of THREADS threads per point cloud on the stream
 * of the calling thread's CUDA thread entry. The launch is asynchronous.
 *
 * \param array The (N, d) point coordinates.
 * \param batch_size The number of point clouds in ``array``.
 * \param sample_points The number of points to sample from each point cloud.
 * \param dist Work buffer of N to-sample distances.
 * \param start_idx The index, within its batch, of the first sample of each batch.
 * \param result Output buffer of ``batch_size * sample_points`` sampled indices.
 */
template <DLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
    NDArray dist, IdArray start_idx, IdArray result) {
  // Nothing to sample. Returning early avoids a division by zero below, a launch
  // with an empty grid, and the kernel's unconditional write of the first sample
  // into an empty result buffer.
  if (batch_size <= 0 || sample_points <= 0) {
    return;
  }

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();

  const FloatType* array_data = static_cast<FloatType*>(array->data);

  const int64_t point_in_batch = array->shape[0] / batch_size;
  const int64_t dim = array->shape[1];

  // return value
  IdType* ret_data = static_cast<IdType*>(result->data);

  // distance
  FloatType* dist_data = static_cast<FloatType*>(dist->data);

  // first sample of each cloud in the batch
  IdType* start_idx_data = static_cast<IdType*>(start_idx->data);

  // one block per point cloud, THREADS threads per block
  fps_kernel<<<batch_size, THREADS, 0, thr_entry->stream>>>(
      array_data, batch_size, sample_points,
      point_in_batch, dim, start_idx_data, dist_data, ret_data);
}
// Explicit instantiations of the GPU (kDLGPU) sampler for every combination of
// float/double coordinates and 32/64-bit integer indices.
template void FarthestPointSampler<kDLGPU, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLGPU, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
} // namespace impl
} // namespace geometry
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file geometry/geometry.cc
* \brief DGL geometry utilities implementation
*/
#include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include "../c_api_common.h"
#include "./geometry_op.h"
using namespace dgl::runtime;
namespace dgl {
namespace geometry {
/*!
 * \brief Validate the arguments and dispatch farthest point sampling to the
 *        implementation matching the device and data types.
 *
 * The float type is taken from ``array``, the index type from ``result`` and
 * the device from ``array``. The implementations reinterpret the buffers of
 * ``dist`` and ``start_idx`` with those same types, so their dtypes and
 * devices are checked here as well.
 *
 * \param array The (N, d) point coordinates of a flattened batch of point clouds.
 * \param batch_size The number of point clouds in ``array``; must divide N.
 * \param sample_points The number of points to sample from each point cloud.
 * \param dist Work buffer of N to-sample distances, same dtype as ``array``.
 * \param start_idx The first sample of each batch, shape (batch_size,),
 *        same dtype as ``result``.
 * \param result Output buffer of shape (batch_size * sample_points,).
 */
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
    NDArray dist, IdArray start_idx, IdArray result) {
  // sizes: the implementations divide by batch_size and always write the first sample
  CHECK_GT(batch_size, 0) << "batch_size should be positive.";
  CHECK_GT(sample_points, 0) << "sample_points should be positive.";
  CHECK_EQ(array->ndim, 2) << "Array should have shape (number of points, dimension).";
  CHECK_EQ(array->shape[0] % batch_size, 0)
    << "The number of points should be divisible by batch_size.";

  // devices: all buffers are accessed by the same implementation
  CHECK_EQ(array->ctx, result->ctx) << "Array and the result should be on the same device.";
  CHECK_EQ(array->ctx, dist->ctx) << "Array and dist should be on the same device.";
  CHECK_EQ(array->ctx, start_idx->ctx) << "Array and start_idx should be on the same device.";

  // dtypes: dist is read as the array's float type, start_idx as the result's id type
  CHECK(dist->dtype.code == array->dtype.code && dist->dtype.bits == array->dtype.bits)
    << "Array and dist should have the same data type.";
  CHECK(start_idx->dtype.code == result->dtype.code &&
        start_idx->dtype.bits == result->dtype.bits)
    << "start_idx and the result should have the same data type.";

  // shapes
  CHECK_EQ(array->shape[0], dist->shape[0]) << "Shape of array and dist mismatch";
  CHECK_EQ(start_idx->shape[0], batch_size) << "Shape of start_idx and batch_size mismatch";
  CHECK_EQ(result->shape[0], batch_size * sample_points) << "Invalid shape of result";

  ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
    ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
      ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "FarthestPointSampler", {
        impl::FarthestPointSampler<XPU, FloatType, IdType>(
            array, batch_size, sample_points, dist, start_idx, result);
      });
    });
  });
}
///////////////////////// C APIs /////////////////////////
// Registers the sampler under "geometry._CAPI_FarthestPointSampler".
// Expected arguments, in order: points, batch size, samples per batch,
// distance buffer, start indices, result buffer. Nothing is written to `rv`;
// the sampled indices are written into the result buffer.
DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // inputs
    const NDArray points = args[0];
    const int64_t num_batches = args[1];
    const int64_t num_samples = args[2];
    // pre-allocated buffers
    NDArray dist_buf = args[3];
    IdArray first_samples = args[4];
    IdArray sampled = args[5];

    FarthestPointSampler(points, num_batches, num_samples, dist_buf, first_samples, sampled);
  });
} // namespace geometry
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file geometry/geometry_op.h
* \brief Geometry operator templates
*/
#ifndef DGL_GEOMETRY_GEOMETRY_OP_H_
#define DGL_GEOMETRY_GEOMETRY_OP_H_
#include <dgl/array.h>
namespace dgl {
namespace geometry {
namespace impl {
/*!
 * \brief Farthest point sampling over a flattened batch of point clouds.
 *
 * Declared here and defined per device type in the implementation files.
 *
 * \tparam XPU The device type the implementation runs on.
 * \tparam FloatType The element type of the point coordinates and distances.
 * \tparam IdType The integer type of the start indices and sampled indices.
 * \param array The point coordinates, of shape (N, d).
 * \param batch_size The number of point clouds in ``array``.
 * \param sample_points The number of points to sample from each point cloud.
 * \param dist Pre-allocated buffer for the to-sample distances.
 * \param start_idx The starting sample of each batch.
 * \param result Pre-allocated buffer that receives the sampled indices.
 */
template <DLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
} // namespace impl
} // namespace geometry
} // namespace dgl
#endif // DGL_GEOMETRY_GEOMETRY_OP_H_
import mxnet as mx
from dgl.geometry.mxnet import FarthestPointSampler
import backend as F
import numpy as np
def test_fps():
    """Smoke test of the MXNet sampler: output shape and a non-zero index sum."""
    total_points = 1000
    num_clouds = 5
    num_samples = 10

    # random point clouds of shape (batch, points per cloud, 3)
    cloud_shape = (num_clouds, int(total_points / num_clouds), 3)
    x = mx.nd.array(np.random.uniform(size=cloud_shape))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.as_in_context(ctx)

    sampler = FarthestPointSampler(num_samples)
    res = sampler(x)

    rows, cols = res.shape[0], res.shape[1]
    assert rows == num_clouds
    assert cols == num_samples
    assert res.sum() > 0


if __name__ == '__main__':
    test_fps()
import torch as th
from dgl.geometry.pytorch import FarthestPointSampler
import backend as F
import numpy as np
def test_fps():
    """Smoke test of the PyTorch sampler: output shape and a non-zero index sum."""
    total_points = 1000
    num_clouds = 5
    num_samples = 10

    # random point clouds of shape (batch, points per cloud, 3)
    cloud_shape = (num_clouds, int(total_points / num_clouds), 3)
    x = th.tensor(np.random.uniform(size=cloud_shape))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)

    sampler = FarthestPointSampler(num_samples)
    res = sampler(x)

    rows, cols = res.shape[0], res.shape[1]
    assert rows == num_clouds
    assert cols == num_samples
    assert res.sum() > 0


if __name__ == '__main__':
    test_fps()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment