Commit 2bf5e763 authored by rusty1s's avatar rusty1s
Browse files

fix cosine knn computation bug

parent 2de4d541
......@@ -8,18 +8,20 @@
template <typename scalar_t> struct Cosine {
static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
int64_t n_a, int64_t n_b,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[i] * b[i];
result += a[n_a * size + i] * b[n_b * size + i];
}
return result;
}
static inline __device__ scalar_t norm(const scalar_t *a, int64_t size) {
static inline __device__ scalar_t norm(const scalar_t *a, int64_t n_a,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[i] * a[i];
result += a[n_a * size + i] * a[n_a * size + i];
}
return sqrt(result);
}
......@@ -50,9 +52,10 @@ __global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
scalar_t tmp_dist = 0;
if (cosine) {
tmp_dist =
Cosine<scalar_t>::norm(x, dim) * Cosine<scalar_t>::norm(y, dim) -
Cosine<scalar_t>::dot(x, y, dim);
tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
(Cosine<scalar_t>::norm(x, n_x, dim) *
Cosine<scalar_t>::norm(y, n_y, dim));
tmp_dist = 1. - tmp_dist;
} else {
for (int64_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
......
......@@ -40,7 +40,7 @@ def test_knn(dtype, device):
if x.is_cuda:
edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
assert to_set(edge_index) == set([(0, 0), (0, 1), (1, 4), (1, 5)])
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment