Unverified Commit 9947d77e authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #28 from jlevy44/master

Attempt at adding cosine similarity metric
parents 453531a5 d678ae82
......@@ -4,17 +4,17 @@
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
// Forward declaration of the CUDA implementation (defined in the .cu file).
// cosine: when true, rank neighbors by cosine distance instead of squared
// Euclidean distance.
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine);

// Validates device/layout preconditions, then dispatches to the CUDA kernel.
//   x, y           : point sets of shape [N, dim] / [M, dim] (CUDA tensors)
//   k              : number of nearest neighbors to find for each y-point
//   batch_x/batch_y: per-point example assignment vectors
//   cosine         : select cosine distance instead of squared Euclidean
// Returns the [2, M * k] edge-index tensor produced by knn_cuda.
// NOTE(review): batch_x/batch_y are checked for device but not contiguity,
// although the kernel indexes their raw data pointers — confirm callers
// always pass contiguous batch vectors.
at::Tensor knn(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
               at::Tensor batch_y, bool cosine) {
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(y);
  IS_CONTIGUOUS(y);
  CHECK_CUDA(batch_x);
  CHECK_CUDA(batch_y);
  return knn_cuda(x, y, k, batch_x, batch_y, cosine);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
......@@ -4,13 +4,34 @@
#define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
// Device helper: dot product of two length-`size` vectors.
// Fixes two defects in the merged code: the function was declared
// `__global__ void` while returning a value (a kernel cannot return a value,
// nor be called as an expression from device code — it must be `__device__`),
// and it operated on `double *` despite being templated on `scalar_t`, which
// would silently break the float instantiation used by knn_kernel.
template <typename scalar_t>
__device__ scalar_t dot(const scalar_t *a, const scalar_t *b, size_t size) {
  scalar_t result = 0;
  for (size_t i = 0; i < size; i++) {
    result += a[i] * b[i];
  }
  return result;
}
// Device helper: Euclidean (L2) norm of a length-`size` vector.
// Fixes the same defects as `dot`: was `__global__ void` with a `return`
// statement (invalid CUDA) and hard-coded `double *` despite the template.
template <typename scalar_t>
__device__ scalar_t norm(const scalar_t *a, size_t size) {
  // sqrt of the self dot product; the explicit template argument keeps the
  // computation in the caller's precision.
  return sqrt(dot<scalar_t>(a, a, size));
}
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ batch_x,
const int64_t *__restrict__ batch_y, scalar_t *__restrict__ dist,
int64_t *__restrict__ row, int64_t *__restrict__ col, size_t k,
size_t dim) {
size_t dim, bool cosine) {
const ptrdiff_t batch_idx = blockIdx.x;
const ptrdiff_t idx = threadIdx.x;
......@@ -30,10 +51,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
// Compare every x-row of this example against the current y-row.
for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
scalar_t tmp_dist = 0;
if (cosine) {
// BUG(review): missing trailing semicolon — this line does not compile.
// BUG(review): passes the base pointers x and y, not the rows being
// compared; should be x + n_x * dim and y + n_y * dim.
// BUG(review): norm/dot are declared `__global__ void` above, so they can
// neither be called as expressions here nor return a value.
// NOTE(review): ||x||*||y|| - <x,y> is an unnormalized surrogate for cosine
// distance; unlike 1 - cos(x, y) it is not scale-invariant — confirm intent.
tmp_dist = norm(x,dim)*norm(y,dim)-dot(x,y,dim)
}
else {
// Squared Euclidean distance between row n_x of x and row n_y of y.
for (ptrdiff_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
}
// Insert tmp_dist into the per-point top-k list (body continues past the
// visible hunk).
for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
if (dist[n_y * k + k_idx_1] > tmp_dist) {
......@@ -51,7 +78,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
}
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
at::Tensor batch_y) {
at::Tensor batch_y, bool cosine) {
cudaSetDevice(x.get_device());
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
......@@ -71,7 +98,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
knn_kernel<scalar_t><<<batch_size, THREADS>>>(
x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
batch_y.data<int64_t>(), dist.data<scalar_t>(), row.data<int64_t>(),
col.data<int64_t>(), k, x.size(1));
col.data<int64_t>(), k, x.size(1), cosine);
});
auto mask = col != -1;
......
......@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__ = '1.4.3a1'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy']
install_requires = ['scipy', 'scikit-learn']
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
......
import torch
import scipy.spatial
import sklearn.neighbors
if torch.cuda.is_available():
import torch_cluster.knn_cuda
def knn(x, y, k, batch_x=None, batch_y=None):
def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
......@@ -54,7 +55,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
assert y.size(0) == batch_y.size(0)
if x.is_cuda:
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y)
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine)
# Rescale x and y.
min_xy = min(x.min().item(), y.min().item())
......@@ -67,9 +68,14 @@ def knn(x, y, k, batch_x=None, batch_y=None):
# Append the batch index as an extra coordinate so points from different
# examples are pushed far apart before the spatial query.
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
# Per-backend keyword arguments for the .query() call below.
query_opts=dict(k=k)
if cosine:
# NOTE(review): sklearn's KDTree does not support 'cosine' (it accepts only
# true metrics) — this likely raises ValueError at construction time;
# NearestNeighbors(metric='cosine') or a BallTree-compatible metric may be
# needed. Confirm.
# NOTE(review): the batch-offset trick above presumes Euclidean geometry;
# under cosine distance it no longer separates batches — confirm.
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine')
else:
# scipy's cKDTree path; distance_upper_bound is a cKDTree-only query option,
# presumably set to exclude cross-batch matches — verify against the offset
# applied above.
tree = scipy.spatial.cKDTree(x.detach().numpy())
query_opts['distance_upper_bound']=x.size(1)
# NOTE(review): diff residue — both the pre-change and post-change
# continuation lines of this call are present below; only the **query_opts
# form should remain after the merge.
dist, col = tree.query(
y.detach().cpu(), k=k, distance_upper_bound=x.size(1))
y.detach().cpu(), **query_opts)
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
# One source-row index per query point, repeated k times to pair with col.
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
......@@ -79,7 +85,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=False):
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
......@@ -110,7 +116,7 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
"""
assert flow in ['source_to_target', 'target_to_source']
row, col = knn(x, x, k if loop else k + 1, batch, batch)
row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine=cosine)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop:
mask = row != col
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment