Commit d678ae82 authored by jlevy44's avatar jlevy44
Browse files

changed CUDA function names, added scikit-learn as a dependency, fixed CPU computation

parent 2388a521
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
#define THREADS 1024 #define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu // Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
__device__ double dotProduct(double *a, double *b, int size) { template <typename scalar_t>
__global__ void
dot(double *a, double *b, size_t size) {
double result = 0; double result = 0;
for(int i = 0; i < size; i++) { for(int i = 0; i < size; i++) {
...@@ -15,8 +17,10 @@ __device__ double dotProduct(double *a, double *b, int size) { ...@@ -15,8 +17,10 @@ __device__ double dotProduct(double *a, double *b, int size) {
return result; return result;
} }
__device__ double calc_norm(double *a, int size) { template <typename scalar_t>
double result = dotProduct(a,a,size); __global__ void
norm(double *a, size_t size) {
double result = dot(a,a,size);
result = sqrt(result); result = sqrt(result);
return result; return result;
} }
...@@ -48,7 +52,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, ...@@ -48,7 +52,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
scalar_t tmp_dist = 0; scalar_t tmp_dist = 0;
if (cosine) { if (cosine) {
tmp_dist = calc_norm(x,dim)*calc_norm(y,dim)-dotProduct(x,y,dim) tmp_dist = norm(x,dim)*norm(y,dim)-dot(x,y,dim)
} }
else { else {
for (ptrdiff_t d = 0; d < dim; d++) { for (ptrdiff_t d = 0; d < dim; d++) {
......
...@@ -31,7 +31,7 @@ if CUDA_HOME is not None: ...@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__ = '1.4.3a1' __version__ = '1.4.3a1'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy'] install_requires = ['scipy', 'scikit-learn']
setup_requires = ['pytest-runner'] setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov'] tests_require = ['pytest', 'pytest-cov']
......
...@@ -68,9 +68,14 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -68,9 +68,14 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1) x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1) y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine' if cosine else 'minkowski')#scipy.spatial.cKDTree(x.detach().numpy()) query_opts=dict(k=k)
if cosine:
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine')
else:
tree = scipy.spatial.cKDTree(x.detach().numpy())
query_opts['distance_upper_bound']=x.size(1)
dist, col = tree.query( dist, col = tree.query(
y.detach().cpu(), k=k)#, distance_upper_bound=x.size(1)) y.detach().cpu(), **query_opts)
dist = torch.from_numpy(dist).to(x.dtype) dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long) col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k) row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment