Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
0adaf7f9
Commit
0adaf7f9
authored
Dec 14, 2021
by
rusty1s
Browse files
improve knn performance
parent
442e8d9c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
66 additions
and
56 deletions
+66
-56
csrc/cpu/utils.h
csrc/cpu/utils.h
+1
-0
csrc/cuda/knn_cuda.cu
csrc/cuda/knn_cuda.cu
+64
-54
test/test_knn.py
test/test_knn.py
+1
-2
No files found.
csrc/cpu/utils.h
View file @
0adaf7f9
...
@@ -4,3 +4,4 @@
...
@@ -4,3 +4,4 @@
// Input-validation helpers shared by the CPU kernels.
// CHECK_CPU:        tensor must live on the CPU device.
// CHECK_INPUT:      generic boolean precondition on the arguments.
// CHECK_CONTIGUOUS: tensor memory must be contiguous (kernels index raw data).
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_CONTIGUOUS(x)                                                    \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
csrc/cuda/knn_cuda.cu
View file @
0adaf7f9
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "utils.cuh"

// Threads per CUDA block (lowered from 1024 to 256 in this commit; the kernel
// now launches one thread per query point rather than one block per example).
#define THREADS 256
template
<
typename
scalar_t
>
struct
Cosine
{
template
<
typename
scalar_t
>
struct
Cosine
{
static
inline
__device__
scalar_t
dot
(
const
scalar_t
*
a
,
const
scalar_t
*
b
,
static
inline
__device__
scalar_t
dot
(
const
scalar_t
*
a
,
const
scalar_t
*
b
,
...
@@ -27,95 +27,105 @@ template <typename scalar_t> struct Cosine {
...
@@ -27,95 +27,105 @@ template <typename scalar_t> struct Cosine {
}
}
};
};
template
<
typename
scalar_t
>
__device__
int64_t
get_example_idx
(
int64_t
idx
,
const
int64_t
*
ptr
,
__global__
void
knn_kernel
(
const
scalar_t
*
x
,
const
scalar_t
*
y
,
const
int64_t
num_examples
)
{
const
int64_t
*
ptr_x
,
const
int64_t
*
ptr_y
,
for
(
int64_t
i
=
0
;
i
<
num_examples
;
i
++
)
{
scalar_t
*
dist
,
int64_t
*
row
,
int64_t
*
col
,
if
(
ptr
[
i
+
1
]
>
idx
)
int64_t
K
,
int64_t
dim
,
bool
cosine
)
{
return
i
;
}
const
int64_t
batch_idx
=
blockIdx
.
x
;
return
num_examples
-
1
;
}
const
int64_t
x_start_idx
=
ptr_x
[
batch_idx
];
const
int64_t
x_end_idx
=
ptr_x
[
batch_idx
+
1
];
const
int64_t
y_start_idx
=
ptr_y
[
batch_idx
];
const
int64_t
y_end_idx
=
ptr_y
[
batch_idx
+
1
];
for
(
int64_t
n_y
=
y_start_idx
+
threadIdx
.
x
;
n_y
<
y_end_idx
;
n_y
+=
THREADS
)
{
for
(
int64_t
k
=
0
;
k
<
K
;
k
++
)
{
row
[
n_y
*
K
+
k
]
=
n_y
;
}
for
(
int64_t
n_x
=
x_start_idx
;
n_x
<
x_end_idx
;
n_x
++
)
{
template
<
typename
scalar_t
>
__global__
void
scalar_t
tmp_dist
=
0
;
knn_kernel
(
const
scalar_t
*
__restrict__
x
,
const
scalar_t
*
__restrict__
y
,
if
(
cosine
)
{
const
int64_t
*
__restrict__
ptr_x
,
const
int64_t
*
__restrict__
ptr_y
,
tmp_dist
=
Cosine
<
scalar_t
>::
dot
(
x
,
y
,
n_x
,
n_y
,
dim
)
/
scalar_t
*
__restrict__
dist
,
int64_t
*
__restrict__
row
,
(
Cosine
<
scalar_t
>::
norm
(
x
,
n_x
,
dim
)
*
int64_t
*
__restrict__
col
,
const
int64_t
k
,
const
int64_t
n
,
Cosine
<
scalar_t
>::
norm
(
y
,
n_y
,
dim
));
const
int64_t
m
,
const
int64_t
dim
,
const
int64_t
num_examples
,
tmp_dist
=
1.
-
tmp_dist
;
const
bool
cosine
)
{
}
else
{
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
const
int64_t
n_y
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
tmp_dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
if
(
n_y
>=
m
)
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
return
;
}
for
(
int64_t
e
=
0
;
e
<
k
;
e
++
)
row
[
n_y
*
k
+
e
]
=
n_y
;
const
int64_t
example_idx
=
get_example_idx
(
n_y
,
ptr_y
,
num_examples
);
for
(
int64_t
n_x
=
ptr_x
[
example_idx
];
n_x
<
ptr_x
[
example_idx
+
1
];
n_x
++
)
{
scalar_t
tmp_dist
=
0
;
if
(
cosine
)
{
tmp_dist
=
Cosine
<
scalar_t
>::
dot
(
x
,
y
,
n_x
,
n_y
,
dim
)
/
(
Cosine
<
scalar_t
>::
norm
(
x
,
n_x
,
dim
)
*
Cosine
<
scalar_t
>::
norm
(
y
,
n_y
,
dim
));
tmp_dist
=
1.
-
tmp_dist
;
}
else
{
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
tmp_dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
}
}
}
for
(
int64_t
k_idx_1
=
0
;
k_idx_1
<
K
;
k_idx_1
++
)
{
for
(
int64_t
e1
=
0
;
e1
<
k
;
e1
++
)
{
if
(
dist
[
n_y
*
K
+
k_idx_1
]
>
tmp_dist
)
{
if
(
dist
[
n_y
*
k
+
e1
]
>
tmp_dist
)
{
for
(
ptrdiff_t
k_idx_2
=
K
-
1
;
k_idx_2
>
k_idx_1
;
k_idx_2
--
)
{
for
(
int64_t
e2
=
k
-
1
;
e2
>
e1
;
e2
--
)
{
dist
[
n_y
*
K
+
k_idx_2
]
=
dist
[
n_y
*
K
+
k_idx_2
-
1
];
dist
[
n_y
*
k
+
e2
]
=
dist
[
n_y
*
k
+
e2
-
1
];
col
[
n_y
*
K
+
k_idx_2
]
=
col
[
n_y
*
K
+
k_idx_2
-
1
];
col
[
n_y
*
k
+
e2
]
=
col
[
n_y
*
k
+
e2
-
1
];
}
dist
[
n_y
*
K
+
k_idx_1
]
=
tmp_dist
;
col
[
n_y
*
K
+
k_idx_1
]
=
n_x
;
break
;
}
}
dist
[
n_y
*
k
+
e1
]
=
tmp_dist
;
col
[
n_y
*
k
+
e1
]
=
n_x
;
break
;
}
}
}
}
}
}
}
}
torch
::
Tensor
knn_cuda
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
Tensor
knn_cuda
(
const
torch
::
Tensor
x
,
const
torch
::
Tensor
y
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
int64_t
k
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
const
int64_t
k
,
bool
cosine
)
{
const
bool
cosine
)
{
CHECK_CUDA
(
x
);
CHECK_CUDA
(
x
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_CUDA
(
y
);
CHECK_CUDA
(
y
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
cudaSetDevice
(
x
.
get_devic
e
());
CHECK_INPUT
(
x
.
size
(
1
)
==
y
.
siz
e
(
1
));
if
(
ptr_x
.
has_value
())
{
if
(
ptr_x
.
has_value
())
{
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
}
else
{
}
else
ptr_x
=
torch
::
arange
(
0
,
x
.
size
(
0
)
+
1
,
x
.
size
(
0
),
ptr_x
=
torch
::
arange
(
0
,
x
.
size
(
0
)
+
1
,
x
.
size
(
0
),
x
.
options
().
dtype
(
torch
::
kLong
));
x
.
options
().
dtype
(
torch
::
kLong
));
}
if
(
ptr_y
.
has_value
())
{
if
(
ptr_y
.
has_value
())
{
CHECK_CUDA
(
ptr_y
.
value
());
CHECK_CUDA
(
ptr_y
.
value
());
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
else
{
}
else
ptr_y
=
torch
::
arange
(
0
,
y
.
size
(
0
)
+
1
,
y
.
size
(
0
),
ptr_y
=
torch
::
arange
(
0
,
y
.
size
(
0
)
+
1
,
y
.
size
(
0
),
y
.
options
().
dtype
(
torch
::
kLong
));
y
.
options
().
dtype
(
torch
::
kLong
));
}
CHECK_INPUT
(
ptr_x
.
value
().
numel
()
==
ptr_y
.
value
().
numel
());
CHECK_INPUT
(
ptr_x
.
value
().
numel
()
==
ptr_y
.
value
().
numel
());
auto
dist
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
1e38
,
y
.
options
());
cudaSetDevice
(
x
.
get_device
());
auto
dist
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
1e10
,
y
.
options
());
auto
row
=
torch
::
empty
(
y
.
size
(
0
)
*
k
,
ptr_y
.
value
().
options
());
auto
row
=
torch
::
empty
(
y
.
size
(
0
)
*
k
,
ptr_y
.
value
().
options
());
auto
col
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
-
1
,
ptr_y
.
value
().
options
());
auto
col
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
-
1
,
ptr_y
.
value
().
options
());
dim3
BLOCKS
((
y
.
size
(
0
)
+
THREADS
-
1
)
/
THREADS
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"knn_kernel"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"knn_kernel"
,
[
&
]
{
knn_kernel
<
scalar_t
><<<
ptr_x
.
value
().
size
(
0
)
-
1
,
THREADS
,
0
,
stream
>>>
(
knn_kernel
<
scalar_t
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
x
.
data_ptr
<
scalar_t
>
(),
y
.
data_ptr
<
scalar_t
>
(),
x
.
data_ptr
<
scalar_t
>
(),
y
.
data_ptr
<
scalar_t
>
(),
ptr_x
.
value
().
data_ptr
<
int64_t
>
(),
ptr_y
.
value
().
data_ptr
<
int64_t
>
(),
ptr_x
.
value
().
data_ptr
<
int64_t
>
(),
ptr_y
.
value
().
data_ptr
<
int64_t
>
(),
dist
.
data_ptr
<
scalar_t
>
(),
row
.
data_ptr
<
int64_t
>
(),
dist
.
data_ptr
<
scalar_t
>
(),
row
.
data_ptr
<
int64_t
>
(),
col
.
data_ptr
<
int64_t
>
(),
k
,
x
.
size
(
1
),
cosine
);
col
.
data_ptr
<
int64_t
>
(),
k
,
x
.
size
(
0
),
y
.
size
(
0
),
x
.
size
(
1
),
ptr_x
.
value
().
numel
()
-
1
,
cosine
);
});
});
auto
mask
=
col
!=
-
1
;
auto
mask
=
col
!=
-
1
;
...
...
test/test_knn.py
View file @
0adaf7f9
...
@@ -71,8 +71,7 @@ def test_knn_graph(dtype, device):
...
@@ -71,8 +71,7 @@ def test_knn_graph(dtype, device):
def
test_knn_graph_large
(
dtype
,
device
):
def
test_knn_graph_large
(
dtype
,
device
):
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
knn_graph
(
x
,
k
=
5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
knn_graph
(
x
,
k
=
5
,
flow
=
'target_to_source'
,
loop
=
True
)
num_workers
=
6
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
_
,
col
=
tree
.
query
(
x
.
cpu
(),
k
=
5
)
_
,
col
=
tree
.
query
(
x
.
cpu
(),
k
=
5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment