"vscode:/vscode.git/clone" did not exist on "56bc006d56a0d4960de6a1e0b6340cba4eda05cd"
Commit 2388a521 authored by jlevy44's avatar jlevy44
Browse files

Attempt at adding cosine similarity metric

parent 453531a5
......@@ -4,17 +4,17 @@
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
// Forward declaration of the CUDA implementation (defined in the .cu file).
// `cosine` selects cosine distance instead of squared Euclidean distance.
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine);

// Host-side entry point for the extension: validates that every tensor lives
// on the GPU and that the point tensors are contiguous, then dispatches to
// the CUDA kernel. Interface matches the Python-side call
// `torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine)`.
at::Tensor knn(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
               at::Tensor batch_y, bool cosine) {
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(y);
  IS_CONTIGUOUS(y);
  CHECK_CUDA(batch_x);
  CHECK_CUDA(batch_y);
  return knn_cuda(x, y, k, batch_x, batch_y, cosine);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
......@@ -4,13 +4,30 @@
#define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
// Dot product of two length-`size` vectors.
//
// Templated (instead of the original hard-coded `double*`) so that it can be
// called from the templated knn_kernel with `const scalar_t *__restrict__`
// arguments: with `double*` parameters the cosine path could not compile for
// float tensors, and const pointers could not bind at all. Calls with
// `double*` arguments still deduce scalar_t = double, so existing uses keep
// their exact behavior.
template <typename scalar_t>
__device__ scalar_t dotProduct(const scalar_t *a, const scalar_t *b,
                               int size) {
  scalar_t result = 0;
  for (int i = 0; i < size; i++) {
    result += a[i] * b[i];
  }
  return result;
}
// Euclidean (L2) norm of a length-`size` vector: sqrt(a . a).
//
// Templated and const-correct for the same reason as dotProduct: the kernel
// passes `const scalar_t *__restrict__`, which the original `double*`
// signature could not accept. For scalar_t = float the CUDA math library's
// float `sqrt` overload is selected; for double, behavior is unchanged.
template <typename scalar_t>
__device__ scalar_t calc_norm(const scalar_t *a, int size) {
  return sqrt(dotProduct(a, a, size));
}
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ batch_x,
const int64_t *__restrict__ batch_y, scalar_t *__restrict__ dist,
int64_t *__restrict__ row, int64_t *__restrict__ col, size_t k,
size_t dim) {
size_t dim, bool cosine) {
const ptrdiff_t batch_idx = blockIdx.x;
const ptrdiff_t idx = threadIdx.x;
......@@ -30,10 +47,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
scalar_t tmp_dist = 0;
for (ptrdiff_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
if (cosine) {
tmp_dist = calc_norm(x,dim)*calc_norm(y,dim)-dotProduct(x,y,dim)
}
else {
for (ptrdiff_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
}
for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
if (dist[n_y * k + k_idx_1] > tmp_dist) {
......@@ -51,7 +74,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
}
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
at::Tensor batch_y) {
at::Tensor batch_y, bool cosine) {
cudaSetDevice(x.get_device());
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
......@@ -71,7 +94,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
knn_kernel<scalar_t><<<batch_size, THREADS>>>(
x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
batch_y.data<int64_t>(), dist.data<scalar_t>(), row.data<int64_t>(),
col.data<int64_t>(), k, x.size(1));
col.data<int64_t>(), k, x.size(1), cosine);
});
auto mask = col != -1;
......
import torch
import scipy.spatial
import sklearn.neighbors
if torch.cuda.is_available():
import torch_cluster.knn_cuda
def knn(x, y, k, batch_x=None, batch_y=None):
def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
......@@ -54,7 +55,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
assert y.size(0) == batch_y.size(0)
if x.is_cuda:
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y)
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine)
# Rescale x and y.
min_xy = min(x.min().item(), y.min().item())
......@@ -67,9 +68,9 @@ def knn(x, y, k, batch_x=None, batch_y=None):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
tree = scipy.spatial.cKDTree(x.detach().numpy())
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine' if cosine else 'minkowski')#scipy.spatial.cKDTree(x.detach().numpy())
dist, col = tree.query(
y.detach().cpu(), k=k, distance_upper_bound=x.size(1))
y.detach().cpu(), k=k)#, distance_upper_bound=x.size(1))
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
......@@ -79,7 +80,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=False):
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
......@@ -110,7 +111,7 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
"""
assert flow in ['source_to_target', 'target_to_source']
row, col = knn(x, x, k if loop else k + 1, batch, batch)
row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine=cosine)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop:
mask = row != col
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment