Commit f7af865f authored by rusty1s's avatar rusty1s
Browse files

typos

parent 5d2168d2
......@@ -11,7 +11,7 @@ template <typename scalar_t> struct Dist<scalar_t, 1> {
static __device__ void
compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
scalar_t *__restrict__ tmp_dist, size_t dim) {
for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
......@@ -29,7 +29,7 @@ template <typename scalar_t> struct Dist<scalar_t, 2> {
static __device__ void
compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
scalar_t *__restrict__ tmp_dist, size_t dim) {
for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
......@@ -49,7 +49,7 @@ template <typename scalar_t> struct Dist<scalar_t, 3> {
static __device__ void
compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
scalar_t *__restrict__ tmp_dist, size_t dim) {
for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
......@@ -70,7 +70,7 @@ template <typename scalar_t> struct Dist<scalar_t, -1> {
static __device__ void
compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
scalar_t *__restrict__ tmp_dist, size_t dim) {
for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
......@@ -96,8 +96,8 @@ template <typename scalar_t> struct Dist<scalar_t, -1> {
template <typename scalar_t, int64_t Dim>
__global__ void
fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
int64_t *__restrict__ cum_k, int64_t *__restrict__ start,
fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
const int64_t *__restrict__ cum_k, const int64_t *__restrict__ start,
scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
int64_t *__restrict__ out, size_t dim) {
......
......@@ -3,12 +3,11 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
at::Tensor batch_x,
at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
at::Tensor batch_y);
std::tuple<at::Tensor, at::Tensor>
nearest(at::Tensor x, at::Tensor y, at::Tensor batch_x, at::Tensor batch_y) {
at::Tensor nearest(at::Tensor x, at::Tensor y, at::Tensor batch_x,
at::Tensor batch_y) {
CHECK_CUDA(x);
IS_CONTIGUOUS(x);
CHECK_CUDA(y);
......
......@@ -5,11 +5,11 @@
#define THREADS 1024
template <typename scalar_t>
__global__ void
nearest_kernel(scalar_t *__restrict__ x, scalar_t *__restrict__ y,
int64_t *__restrict__ batch_x, int64_t *__restrict__ batch_y,
scalar_t *__restrict__ out, int64_t *__restrict__ out_idx,
size_t dim) {
__global__ void nearest_kernel(const scalar_t *__restrict__ x,
const scalar_t *__restrict__ y,
const int64_t *__restrict__ batch_x,
const int64_t *__restrict__ batch_y,
int64_t *__restrict__ out, const size_t dim) {
const ptrdiff_t n_x = blockIdx.x;
const ptrdiff_t batch_idx = batch_x[n_x];
......@@ -55,13 +55,11 @@ nearest_kernel(scalar_t *__restrict__ x, scalar_t *__restrict__ y,
__syncthreads();
if (idx == 0) {
out[n_x] = best_dist[0];
out_idx[n_x] = best_dist_idx[0];
out[n_x] = best_dist_idx[0];
}
}
std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
at::Tensor batch_x,
at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
at::Tensor batch_y) {
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
......@@ -71,15 +69,13 @@ std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
batch_y = degree(batch_y, batch_size);
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
auto out = at::empty(x.size(0), x.options());
auto out_idx = at::empty_like(batch_x);
auto out = at::empty_like(batch_x);
AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES(x.type(), "nearest_kernel", [&] {
nearest_kernel<scalar_t><<<x.size(0), THREADS>>>(
x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
batch_y.data<int64_t>(), out.data<scalar_t>(), out_idx.data<int64_t>(),
x.size(1));
batch_y.data<int64_t>(), out.data<int64_t>(), x.size(1));
});
return std::make_tuple(out, out_idx);
return out;
}
......@@ -32,6 +32,5 @@ def test_nearest(dtype, device):
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
batch_y = tensor([0, 0, 1, 1], torch.long, device)
dist, idx = nearest(x, y, batch_x, batch_y)
assert dist.tolist() == [1, 1, 1, 1, 2, 2, 2, 2]
assert idx.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
out = nearest(x, y, batch_x, batch_y)
assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
......@@ -29,13 +29,13 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)
x = x.view(-1, 1) if x.dim() == 1 else x
assert x.is_cuda
assert x.dim() <= 2 and batch.dim() == 1
assert x.dim() == 2 and batch.dim() == 1
assert x.size(0) == batch.size(0)
assert ratio > 0 and ratio < 1
x = x.view(-1, 1) if x.dim() == 1 else x
op = fps_cuda.fps if x.is_cuda else None
out = op(x, batch, ratio, random_start)
......
......@@ -22,6 +22,6 @@ def nearest(x, y, batch_x=None, batch_y=None):
assert y.size(0) == batch_y.size(0)
op = nearest_cuda.nearest if x.is_cuda else None
dist, idx = op(x, y, batch_x, batch_y)
out = op(x, y, batch_x, batch_y)
return dist, idx
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment