Commit d678ae82 authored by jlevy44's avatar jlevy44
Browse files

changed cuda function names, added sklearn as dependency, fixed cpu computation

parent 2388a521
......@@ -5,7 +5,9 @@
#define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
__device__ double dotProduct(double *a, double *b, int size) {
template <typename scalar_t>
__global__ void
dot(double *a, double *b, size_t size) {
double result = 0;
for(int i = 0; i < size; i++) {
......@@ -15,8 +17,10 @@ __device__ double dotProduct(double *a, double *b, int size) {
return result;
}
__device__ double calc_norm(double *a, int size) {
double result = dotProduct(a,a,size);
template <typename scalar_t>
__global__ void
norm(double *a, size_t size) {
double result = dot(a,a,size);
result = sqrt(result);
return result;
}
......@@ -48,7 +52,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
scalar_t tmp_dist = 0;
if (cosine) {
tmp_dist = calc_norm(x,dim)*calc_norm(y,dim)-dotProduct(x,y,dim)
tmp_dist = norm(x,dim)*norm(y,dim)-dot(x,y,dim)
}
else {
for (ptrdiff_t d = 0; d < dim; d++) {
......
......@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__ = '1.4.3a1'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy']
install_requires = ['scipy', 'scikit-learn']
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
......
......@@ -68,9 +68,14 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine' if cosine else 'minkowski')#scipy.spatial.cKDTree(x.detach().numpy())
query_opts=dict(k=k)
if cosine:
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine')
else:
tree = scipy.spatial.cKDTree(x.detach().numpy())
query_opts['distance_upper_bound']=x.size(1)
dist, col = tree.query(
y.detach().cpu(), k=k)#, distance_upper_bound=x.size(1))
y.detach().cpu(), **query_opts)
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment