Commit 37c5c8e0 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot

Linter, deprecated type()

Summary: Run the linter after recent changes. Fix a long comment in knn.h which clang-format had reflowed badly. Add a crude test that the csrc code doesn't call the deprecated `.type()` or `.data()`.

Reviewed By: nikhilaravi

Differential Revision: D20692935

fbshipit-source-id: 28ce0308adae79a870cb41a810b7cf8744f41ab8
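
For context, the replacements for the deprecated spellings (which the new test greps for) follow the pattern in this sketch. It is illustrative only, not code from the diff; `t` and `example_op` are hypothetical:

#include <ATen/ATen.h>

// Before (deprecated): AT_DISPATCH_FLOATING_TYPES(t.type(), ...),
// t.data(), and t.type().is_cuda().
void example(at::Tensor t) {
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "example_op", ([&] {
    scalar_t* ptr = t.data_ptr<scalar_t>(); // typed replacement for .data()
  }));
  bool on_gpu = t.is_cuda(); // replacement for t.type().is_cuda()
}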
parent 3061c5b6
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of
// functions. This is especially useful for CUDA kernels, since specializing
// them to particular input sizes can often allow the compiler to unroll loops
// and place arrays into registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
@@ -92,14 +92,13 @@ namespace {
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (curN == N) {
@@ -108,22 +107,21 @@ struct DispatchKernelHelper1D {
      Kernel<T, curN>::run(args...);
    } else if (curN < N) {
      // Increment curN via template recursion
      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
          N, args...);
    }
    // We shouldn't get here -- throw an error?
  }
};
// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (N == maxN) {
@@ -133,19 +131,21 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  }
};

// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && curM == M) {
@@ -154,67 +154,141 @@ struct DispatchKernelHelper2D {
      // Increment both curN and curM. This isn't strictly necessary; we could
      // just increment one or the other at each step. But this helps to cut
      // on the number of recursive calls we make.
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          curM + 1,
          Args...>::run(N, M, args...);
    } else if (curN < N) {
      // Increment curN only
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          curM,
          Args...>::run(N, M, args...);
    } else if (curM < M) {
      // Increment curM only
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN,
          minM,
          maxM,
          curM + 1,
          Args...>::run(N, M, args...);
    }
  }
};
// 2D dispatch, specialization for curN == maxN
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    curM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && curM == M) {
      Kernel<T, maxN, curM>::run(args...);
    } else if (curM < maxM) {
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          maxN,
          minM,
          maxM,
          curM + 1,
          Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};
// 2D dispatch, specialization for curM == maxM
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    curN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && maxM == M) {
      Kernel<T, curN, maxM>::run(args...);
    } else if (curN < maxN) {
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          maxM,
          Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};
// 2D dispatch, specialization for curN == maxN, curM == maxM
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && maxM == M) {
      Kernel<T, maxN, maxM>::run(args...);
@@ -225,37 +299,45 @@ struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Arg
} // namespace

// This is the function we expect users to call to dispatch to 1D functions
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
void DispatchKernel1D(const int64_t N, Args... args) {
  if (minN <= N && N <= maxN) {
    // Kick off the template recursion by calling the Helper with curN = minN
    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
        N, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the allowed range?
}
// This is the function we expect users to call to dispatch to 2D functions
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
  if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
    // Kick off the template recursion by calling the Helper with curN = minN
    // and curM = minM
    DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        minN,
        minM,
        maxM,
        minM,
        Args...>::run(N, M, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the specified range?
}
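
To make the calling convention concrete, here is a usage sketch (an editor's illustration, not part of this commit; `ScaleKernel` and `scale_by_n` are hypothetical):

// A kernel functor specialized on a compile-time size N. Any class template
// of the form Kernel<T, N> with a static run() can be dispatched this way.
template <typename T, int64_t N>
struct ScaleKernel {
  static void run(const T* in, T* out, int64_t count) {
    for (int64_t i = 0; i < count; ++i) {
      out[i] = in[i] * static_cast<T>(N); // N is a compile-time constant here
    }
  }
};

void scale_by_n(const float* in, float* out, int64_t count, int64_t n) {
  // Instantiates ScaleKernel<float, 1> through ScaleKernel<float, 8> and runs
  // the one matching the runtime value n (silently a no-op if n is outside
  // [1, 8]). DispatchKernel2D works the same way with two runtime sizes.
  DispatchKernel1D<ScaleKernel, float, 1, 8>(n, in, out, count);
}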
@@ -39,82 +39,180 @@
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template <typename T, int N>
struct RegisterIndexUtils {
  __device__ __forceinline__ static T get(const T arr[N], int idx) {
    if (idx < 0 || idx >= N)
      return T();
    switch (idx) {
      case 0:
        return arr[0];
      case 1:
        return arr[1];
      case 2:
        return arr[2];
      case 3:
        return arr[3];
      case 4:
        return arr[4];
      case 5:
        return arr[5];
      case 6:
        return arr[6];
      case 7:
        return arr[7];
      case 8:
        return arr[8];
      case 9:
        return arr[9];
      case 10:
        return arr[10];
      case 11:
        return arr[11];
      case 12:
        return arr[12];
      case 13:
        return arr[13];
      case 14:
        return arr[14];
      case 15:
        return arr[15];
      case 16:
        return arr[16];
      case 17:
        return arr[17];
      case 18:
        return arr[18];
      case 19:
        return arr[19];
      case 20:
        return arr[20];
      case 21:
        return arr[21];
      case 22:
        return arr[22];
      case 23:
        return arr[23];
      case 24:
        return arr[24];
      case 25:
        return arr[25];
      case 26:
        return arr[26];
      case 27:
        return arr[27];
      case 28:
        return arr[28];
      case 29:
        return arr[29];
      case 30:
        return arr[30];
      case 31:
        return arr[31];
    };
    return T();
  }

  __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
    if (idx < 0 || idx >= N)
      return;
    switch (idx) {
      case 0:
        arr[0] = val;
        break;
      case 1:
        arr[1] = val;
        break;
      case 2:
        arr[2] = val;
        break;
      case 3:
        arr[3] = val;
        break;
      case 4:
        arr[4] = val;
        break;
      case 5:
        arr[5] = val;
        break;
      case 6:
        arr[6] = val;
        break;
      case 7:
        arr[7] = val;
        break;
      case 8:
        arr[8] = val;
        break;
      case 9:
        arr[9] = val;
        break;
      case 10:
        arr[10] = val;
        break;
      case 11:
        arr[11] = val;
        break;
      case 12:
        arr[12] = val;
        break;
      case 13:
        arr[13] = val;
        break;
      case 14:
        arr[14] = val;
        break;
      case 15:
        arr[15] = val;
        break;
      case 16:
        arr[16] = val;
        break;
      case 17:
        arr[17] = val;
        break;
      case 18:
        arr[18] = val;
        break;
      case 19:
        arr[19] = val;
        break;
      case 20:
        arr[20] = val;
        break;
      case 21:
        arr[21] = val;
        break;
      case 22:
        arr[22] = val;
        break;
      case 23:
        arr[23] = val;
        break;
      case 24:
        arr[24] = val;
        break;
      case 25:
        arr[25] = val;
        break;
      case 26:
        arr[26] = val;
        break;
      case 27:
        arr[27] = val;
        break;
      case 28:
        arr[28] = val;
        break;
      case 29:
        arr[29] = val;
        break;
      case 30:
        arr[30] = val;
        break;
      case 31:
        arr[31] = val;
        break;
    }
  }
};
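
A sketch of the intended access pattern (illustrative assumption, not from this diff; `demo_kernel` is hypothetical): a small thread-local scratch array is only ever indexed through get/set, so once the switch is resolved every access uses a compile-time index and the array is a candidate for register promotion.

template <typename T>
__global__ void demo_kernel(const T* in, T* out, int n) {
  T buf[8]; // thread-local scratch, never indexed with a dynamic subscript
  for (int i = 0; i < 8; ++i) {
    RegisterIndexUtils<T, 8>::set(buf, i, T(0));
  }
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    RegisterIndexUtils<T, 8>::set(buf, tid % 8, in[tid]);
    out[tid] = RegisterIndexUtils<T, 8>::get(buf, tid % 8);
  }
}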
@@ -289,7 +289,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
  const size_t threads = 256;
  const size_t blocks = 256;
  if (version == 0) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      KNearestNeighborKernelV0<scalar_t>
          <<<blocks, threads>>>(
              p1.data_ptr<scalar_t>(),
@@ -303,7 +303,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
              K);
    }));
  } else if (version == 1) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel1D<
          KNearestNeighborV1Functor,
          scalar_t,
@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
          K);
    }));
  } else if (version == 2) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV2Functor,
          scalar_t,
@@ -343,7 +343,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
          P2);
    }));
  } else if (version == 3) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV3Functor,
          scalar_t,
...
@@ -13,11 +13,11 @@
//        containing P1 points of dimension D.
//    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
//        containing P2 points of dimension D.
//    K: int giving the number of nearest points to return.
//    sorted: bool telling whether to sort the K returned points by their
//        distance.
//    version: Integer telling which implementation to use.
//        TODO(jcjohns): Document this more, or maybe remove it before landing.
//
// Returns:
//    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -41,7 +41,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
    const at::Tensor& p2,
    int K,
    int version) {
  if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(p1);
    CHECK_CONTIGUOUS_CUDA(p2);
...
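
For orientation, a call to the entry point documented above might look like the following sketch (illustrative; the shapes and the `version` choice are assumptions, and per the CPU implementation below the returned distances are squared):

// K = 4 nearest neighbors in p2 for every point in p1.
at::Tensor p1 = at::rand({8, 128, 3}); // (N, P1, D)
at::Tensor p2 = at::rand({8, 256, 3}); // (N, P2, D)
auto out = KNearestNeighborIdx(p1, p2, /*K=*/4, /*version=*/0);
at::Tensor idx = std::get<0>(out);   // (8, 128, 4) indices into p2
at::Tensor dists = std::get<1>(out); // (8, 128, 4) squared distances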
@@ -4,49 +4,48 @@
#include <queue>
#include <tuple>

std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int D = p1.size(2);
  const int P2 = p2.size(1);

  auto long_opts = p1.options().dtype(torch::kInt64);
  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto dists_a = dists.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    for (int i1 = 0; i1 < P1; ++i1) {
      // Use a priority queue to store (distance, index) tuples. The queue is
      // a max-heap of at most K elements, so q.top() is always the largest of
      // the K smallest distances seen so far.
      std::priority_queue<std::tuple<float, int>> q;
      for (int i2 = 0; i2 < P2; ++i2) {
        float dist = 0;
        for (int d = 0; d < D; ++d) {
          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
          dist += diff * diff;
        }
        int size = static_cast<int>(q.size());
        if (size < K || dist < std::get<0>(q.top())) {
          q.emplace(dist, i2);
          if (size >= K) {
            q.pop();
          }
        }
      }
      // Popping the max-heap yields neighbors from farthest to nearest, so
      // writing each popped entry to slot q.size() fills the output in
      // ascending order of distance.
      while (!q.empty()) {
        auto t = q.top();
        q.pop();
        const int k = q.size();
        dists_a[n][i1][k] = std::get<0>(t);
        idxs_a[n][i1][k] = std::get<1>(t);
      }
    }
  }
  return std::make_tuple(idxs, dists);
}
@@ -5,7 +5,6 @@
#include "index_utils.cuh"

// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
@@ -32,18 +31,17 @@
//     float key_k = keys[k];
//     int value_k = values[k];
//   }
template <typename key_t, typename value_t>
class MinK {
 public:
  // Constructor.
  //
  // Arguments:
  //   keys: Array in which to store keys
  //   values: Array in which to store values
  //   K: How many values to keep track of
  __device__ MinK(key_t* keys, value_t* vals, int K)
      : keys(keys), vals(vals), K(K), _size(0) {}

  // Try to add a new key and associated value to the data structure. If the key
  // is one of the smallest K seen so far then it will be kept; otherwise it
@@ -55,7 +53,7 @@ class MinK {
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      keys[_size] = key;
      vals[_size] = val;
@@ -71,8 +69,8 @@ class MinK {
      for (int k = 0; k < K; ++k) {
        key_t cur_key = keys[k];
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
@@ -102,15 +100,14 @@ class MinK {
  }

 private:
  key_t* keys;
  value_t* vals;
  int K;
  int _size;
  key_t max_key;
  int max_idx;
};
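
A device-side usage sketch (illustrative, not from this diff; `three_nearest` is hypothetical, and the read-back loop assumes at least K keys were added so all K slots are filled):

// Track the K = 3 smallest squared distances and their point indices.
__device__ void three_nearest(const float* sq_dists, int P, int* out_idx) {
  float keys[3];
  int vals[3];
  MinK<float, int> mink(keys, vals, 3);
  for (int p = 0; p < P; ++p) {
    mink.add(sq_dists[p], p);
  }
  // As in the usage comment above, results are read straight out of the
  // caller-owned arrays (in no particular order).
  for (int k = 0; k < 3; ++k) {
    out_idx[k] = vals[k];
  }
}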
// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
@@ -120,13 +117,13 @@ class MinK {
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template <typename key_t, typename value_t, int K>
class RegisterMinK {
 public:
  __device__ RegisterMinK(key_t* keys, value_t* vals)
      : keys(keys), vals(vals), _size(0) {}

  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      RegisterIndexUtils<key_t, K>::set(keys, _size, key);
      RegisterIndexUtils<value_t, K>::set(vals, _size, val);
@@ -154,8 +151,8 @@ class RegisterMinK {
  }

 private:
  key_t* keys;
  value_t* vals;
  int _size;
  key_t max_key;
  int max_idx;
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
from pytorch3d import _C
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from itertools import product

import torch
from fvcore.common.benchmark import benchmark
@@ -30,21 +29,13 @@ def benchmark_knn_cuda_versions() -> None:
                    continue
                if version == 3 and K > 4:
                    continue
                knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
    for N, P, D in product(Ns, Ps, Ds):
        nn_kwargs.append({"N": N, "D": D, "P": P})
    benchmark(
        knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1
    )
    benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)


def benchmark_knn_cuda_vs_naive() -> None:
@@ -55,21 +46,16 @@ def benchmark_knn_cuda_vs_naive() -> None:
    Ks = [1, 2, 4, 8, 16]
    knn_kwargs, naive_kwargs = [], []
    for N, P, D, K in product(Ns, Ps, Ds, Ks):
        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
        if P <= 4096:
            naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
    benchmark(
        knn_python_cuda_with_init,
        "KNN_CUDA_PYTHON",
        naive_kwargs,
        warmup_iters=1,
    )
    benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)


def benchmark_knn_cpu() -> None:
@@ -79,31 +65,18 @@ def benchmark_knn_cpu() -> None:
    Ks = [1, 2, 4]
    knn_kwargs, nn_kwargs = [], []
    for N, P, D, K in product(Ns, Ps, Ds, Ks):
        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
    for N, P, D in product(Ns, Ps, Ds):
        nn_kwargs.append({"N": N, "D": D, "P": P})
    benchmark(
        knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1
    )
    benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
    benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)


def knn_cuda_with_init(N, D, P, K, v=-1):
    device = torch.device("cuda:0")
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()
@@ -116,7 +89,7 @@ def knn_cuda_with_init(N, D, P, K, v=-1):


def knn_cpu_with_init(N, D, P, K):
    device = torch.device("cpu")
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
@@ -127,7 +100,7 @@ def knn_cpu_with_init(N, D, P, K):


def knn_python_cuda_with_init(N, D, P, K):
    device = torch.device("cuda")
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()
@@ -140,7 +113,7 @@ def knn_python_cuda_with_init(N, D, P, K):


def knn_python_cpu_with_init(N, D, P, K):
    device = torch.device("cpu")
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
@@ -151,7 +124,7 @@ def knn_python_cpu_with_init(N, D, P, K):


def nn_cuda_with_init(N, D, P):
    device = torch.device("cuda")
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()
@@ -164,7 +137,7 @@ def nn_cuda_with_init(N, D, P):


def nn_cpu_with_init(N, D, P):
    device = torch.device("cpu")
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
...
@@ -22,6 +22,27 @@ class TestBuild(unittest.TestCase):
        for k, v in counter.items():
            self.assertEqual(v, 1, f"Too many files with stem {k}.")

    def test_deprecated_usage(self):
        # Check certain expressions do not occur in the csrc code
        test_dir = Path(__file__).resolve().parent
        source_dir = test_dir.parent / "pytorch3d" / "csrc"
        files = sorted(source_dir.glob("**/*.*"))
        self.assertGreater(len(files), 4)
        patterns = [".type()", ".data()"]
        for file in files:
            with open(file) as f:
                text = f.read()
            for pattern in patterns:
                found = pattern in text
                msg = (
                    f"{pattern} found in {file.name}"
                    + ", this has been deprecated."
                )
                self.assertFalse(found, msg)

    def test_copyright(self):
        test_dir = Path(__file__).resolve().parent
        root_dir = test_dir.parent
...
@@ -28,7 +28,7 @@ class TestKNN(unittest.TestCase):
    def test_knn_vs_python_cpu(self):
        """ Test CPU output vs PyTorch implementation """
        device = torch.device("cpu")
        Ns = [1, 4]
        Ds = [2, 3]
        P1s = [1, 10, 101]
@@ -45,7 +45,7 @@ class TestKNN(unittest.TestCase):
    def test_knn_vs_python_cuda(self):
        """ Test CUDA output vs PyTorch implementation """
        device = torch.device("cuda")
        Ns = [1, 4]
        Ds = [2, 3, 8]
        P1s = [1, 8, 64, 128, 1001]
...