Commit 7318c1b8 authored by rusty1s
Browse files

clean up cosine distance

parent 9947d77e
...@@ -4,26 +4,24 @@ ...@@ -4,26 +4,24 @@
#define THREADS 1024 #define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu template <typename scalar_t> struct Cosine {
template <typename scalar_t> static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
__global__ void size_t size) {
dot(double *a, double *b, size_t size) { scalar_t result = 0;
double result = 0; for (ptrdiff_t i = 0; i < size; i++) {
result += a[i] * b[i];
for(int i = 0; i < size; i++) {
result += a[i] * b[i];
} }
return result; return result;
} }
template <typename scalar_t> static inline __device__ scalar_t norm(const scalar_t *a, size_t size) {
__global__ void scalar_t result = 0;
norm(double *a, size_t size) { for (ptrdiff_t i = 0; i < size; i++) {
double result = dot(a,a,size); result += a[i] * a[i];
result = sqrt(result); }
return result; return sqrt(result);
} }
};
template <typename scalar_t> template <typename scalar_t>
__global__ void __global__ void
...@@ -52,16 +50,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, ...@@ -52,16 +50,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
scalar_t tmp_dist = 0; scalar_t tmp_dist = 0;
if (cosine) { if (cosine) {
tmp_dist = norm(x,dim)*norm(y,dim)-dot(x,y,dim) tmp_dist =
} Cosine<scalar_t>::norm(x, dim) * Cosine<scalar_t>::norm(y, dim) -
else { Cosine<scalar_t>::dot(x, y, dim);
} else {
for (ptrdiff_t d = 0; d < dim; d++) { for (ptrdiff_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) * tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]); (x[n_x * dim + d] - y[n_y * dim + d]);
} }
} }
for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) { for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
if (dist[n_y * k + k_idx_1] > tmp_dist) { if (dist[n_y * k + k_idx_1] > tmp_dist) {
for (ptrdiff_t k_idx_2 = k - 1; k_idx_2 > k_idx_1; k_idx_2--) { for (ptrdiff_t k_idx_2 = k - 1; k_idx_2 > k_idx_1; k_idx_2--) {
......
...@@ -31,7 +31,7 @@ if CUDA_HOME is not None: ...@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__ = '1.4.3a1' __version__ = '1.4.3a1'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy', 'scikit-learn'] install_requires = ['scipy']
setup_requires = ['pytest-runner'] setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov'] tests_require = ['pytest', 'pytest-cov']
......
...@@ -33,6 +33,11 @@ def test_knn(dtype, device): ...@@ -33,6 +33,11 @@ def test_knn(dtype, device):
assert row.tolist() == [0, 0, 1, 1] assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [2, 3, 4, 5] assert col.tolist() == [2, 3, 4, 5]
if x.is_cuda:
row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [0, 1, 4, 5]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph(dtype, device): def test_knn_graph(dtype, device):
......
import torch import torch
import scipy.spatial import scipy.spatial
import sklearn.neighbors
if torch.cuda.is_available(): if torch.cuda.is_available():
import torch_cluster.knn_cuda import torch_cluster.knn_cuda
...@@ -22,6 +21,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -22,6 +21,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
batch_y (LongTensor, optional): Batch vector batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`) node to a specific example. (default: :obj:`None`)
cosine (boolean, optional): If :obj:`True`, will use the cosine
distance instead of euclidean distance to find nearest neighbors.
(default: :obj:`False`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -57,6 +59,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -57,6 +59,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
if x.is_cuda: if x.is_cuda:
return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine) return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine)
if cosine:
raise NotImplementedError('Cosine distance not implemented for CPU')
# Rescale x and y. # Rescale x and y.
min_xy = min(x.min().item(), y.min().item()) min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy x, y = x - min_xy, y - min_xy
...@@ -68,14 +73,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -68,14 +73,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1) x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1) y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
query_opts=dict(k=k) tree = scipy.spatial.cKDTree(x.detach().numpy())
if cosine: dist, col = tree.query(y.detach().cpu(), k=k,
tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine') distance_upper_bound=x.size(1))
else:
tree = scipy.spatial.cKDTree(x.detach().numpy())
query_opts['distance_upper_bound']=x.size(1)
dist, col = tree.query(
y.detach().cpu(), **query_opts)
dist = torch.from_numpy(dist).to(x.dtype) dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long) col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k) row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
...@@ -85,7 +85,8 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False): ...@@ -85,7 +85,8 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
return torch.stack([row, col], dim=0) return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=False): def knn_graph(x, k, batch=None, loop=False, flow='source_to_target',
cosine=False):
r"""Computes graph edges to the nearest :obj:`k` points. r"""Computes graph edges to the nearest :obj:`k` points.
Args: Args:
...@@ -100,6 +101,9 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=Fals ...@@ -100,6 +101,9 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=Fals
flow (string, optional): The flow direction when using in combination flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
cosine (boolean, optional): If :obj:`True`, will use the cosine
distance instead of euclidean distance to find nearest neighbors.
(default: :obj:`False`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment