Commit 7318c1b8 authored by rusty1s
Browse files

clean up cosine distance

parent 9947d77e
......@@ -4,26 +4,24 @@
#define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
template <typename scalar_t>
__global__ void
dot(double *a, double *b, size_t size) {
double result = 0;
for(int i = 0; i < size; i++) {
template <typename scalar_t> struct Cosine {
// Inner product of the first `size` entries of `a` and `b`.
// The sum is accumulated left to right in scalar_t precision.
static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                      size_t size) {
  scalar_t acc = 0;
  for (size_t idx = 0; idx != size; ++idx) {
    acc += a[idx] * b[idx];
  }
  return acc;
}
}
template <typename scalar_t>
__global__ void
norm(double *a, size_t size) {
double result = dot(a,a,size);
result = sqrt(result);
return result;
}
// Square root of the sum of squares of the first `size` entries of `a`
// (the Euclidean length of that span), accumulated in scalar_t precision.
static inline __device__ scalar_t norm(const scalar_t *a, size_t size) {
  scalar_t sum_sq = 0;
  for (size_t idx = 0; idx != size; ++idx) {
    sum_sq += a[idx] * a[idx];
  }
  return sqrt(sum_sq);
}
};
template <typename scalar_t>
__global__ void
......@@ -52,16 +50,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
scalar_t tmp_dist = 0;
if (cosine) {
tmp_dist = norm(x,dim)*norm(y,dim)-dot(x,y,dim)
}
else {
tmp_dist =
Cosine<scalar_t>::norm(x, dim) * Cosine<scalar_t>::norm(y, dim) -
Cosine<scalar_t>::dot(x, y, dim);
} else {
for (ptrdiff_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
}
for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
if (dist[n_y * k + k_idx_1] > tmp_dist) {
for (ptrdiff_t k_idx_2 = k - 1; k_idx_2 > k_idx_1; k_idx_2--) {
......
......@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__ = '1.4.3a1'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy', 'scikit-learn']
install_requires = ['scipy']
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
......
......@@ -33,6 +33,11 @@ def test_knn(dtype, device):
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 4, 5]
if x.is_cuda:
row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [0, 1, 4, 5]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph(dtype, device):
......
import torch
import scipy.spatial
import sklearn.neighbors
if torch.cuda.is_available():
import torch_cluster.knn_cuda
......@@ -22,6 +21,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
cosine (boolean, optional): If :obj:`True`, uses the cosine
distance instead of the Euclidean distance to find nearest neighbors.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
......@@ -57,6 +59,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
if x.is_cuda:
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine)
if cosine:
raise NotImplementedError('Cosine distance not implemented for CPU')
# Rescale x and y.
min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy
......@@ -68,14 +73,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
query_opts=dict(k=k)
if cosine:
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine')
else:
tree = scipy.spatial.cKDTree(x.detach().numpy())
query_opts['distance_upper_bound']=x.size(1)
dist, col = tree.query(
y.detach().cpu(), **query_opts)
dist, col = tree.query(y.detach().cpu(), k=k,
distance_upper_bound=x.size(1))
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
......@@ -85,7 +85,8 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=False):
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target',
cosine=False):
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
......@@ -100,6 +101,9 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=Fals
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
cosine (boolean, optional): If :obj:`True`, uses the cosine
distance instead of the Euclidean distance to find nearest neighbors.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment