Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
0d735d7e
Commit
0d735d7e
authored
Dec 14, 2021
by
rusty1s
Browse files
improve radius performance
parent
0adaf7f9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
88 additions
and
71 deletions
+88
-71
csrc/cpu/utils.h
csrc/cpu/utils.h
+2
-1
csrc/cuda/knn_cuda.cu
csrc/cuda/knn_cuda.cu
+26
-25
csrc/cuda/radius_cuda.cu
csrc/cuda/radius_cuda.cu
+48
-44
csrc/cuda/utils.cuh
csrc/cuda/utils.cuh
+11
-0
test/test_radius.py
test/test_radius.py
+1
-1
No files found.
csrc/cpu/utils.h
View file @
0d735d7e
...
@@ -4,4 +4,5 @@
...
@@ -4,4 +4,5 @@
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
csrc/cuda/knn_cuda.cu
View file @
0d735d7e
...
@@ -27,33 +27,28 @@ template <typename scalar_t> struct Cosine {
...
@@ -27,33 +27,28 @@ template <typename scalar_t> struct Cosine {
}
}
};
};
__device__
int64_t
get_example_idx
(
int64_t
idx
,
const
int64_t
*
ptr
,
const
int64_t
num_examples
)
{
for
(
int64_t
i
=
0
;
i
<
num_examples
;
i
++
)
{
if
(
ptr
[
i
+
1
]
>
idx
)
return
i
;
}
return
num_examples
-
1
;
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
__global__
void
knn_kernel
(
const
scalar_t
*
__restrict__
x
,
const
scalar_t
*
__restrict__
y
,
knn_kernel
(
const
scalar_t
*
__restrict__
x
,
const
scalar_t
*
__restrict__
y
,
const
int64_t
*
__restrict__
ptr_x
,
const
int64_t
*
__restrict__
ptr_y
,
const
int64_t
*
__restrict__
ptr_x
,
const
int64_t
*
__restrict__
ptr_y
,
scalar_t
*
__restrict__
dist
,
int64_t
*
__restrict__
row
,
int64_t
*
__restrict__
row
,
int64_t
*
__restrict__
col
,
int64_t
*
__restrict__
col
,
const
int64_t
k
,
const
int64_t
n
,
const
int64_t
k
,
const
int64_t
n
,
const
int64_t
m
,
const
int64_t
dim
,
const
int64_t
m
,
const
int64_t
dim
,
const
int64_t
num_examples
,
const
int64_t
num_examples
,
const
bool
cosine
)
{
const
bool
cosine
)
{
const
int64_t
n_y
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
n_y
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
n_y
>=
m
)
if
(
n_y
>=
m
)
return
;
return
;
for
(
int64_t
e
=
0
;
e
<
k
;
e
++
)
row
[
n_y
*
k
+
e
]
=
n_y
;
const
int64_t
example_idx
=
get_example_idx
(
n_y
,
ptr_y
,
num_examples
);
const
int64_t
example_idx
=
get_example_idx
(
n_y
,
ptr_y
,
num_examples
);
scalar_t
best_dist
[
100
];
int64_t
best_idx
[
100
];
for
(
int
e
=
0
;
e
<
k
;
e
++
)
{
best_dist
[
e
]
=
1e10
;
best_idx
[
e
]
=
-
1
;
}
for
(
int64_t
n_x
=
ptr_x
[
example_idx
];
n_x
<
ptr_x
[
example_idx
+
1
];
n_x
++
)
{
for
(
int64_t
n_x
=
ptr_x
[
example_idx
];
n_x
<
ptr_x
[
example_idx
+
1
];
n_x
++
)
{
scalar_t
tmp_dist
=
0
;
scalar_t
tmp_dist
=
0
;
...
@@ -70,17 +65,22 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
...
@@ -70,17 +65,22 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
}
}
for
(
int64_t
e1
=
0
;
e1
<
k
;
e1
++
)
{
for
(
int64_t
e1
=
0
;
e1
<
k
;
e1
++
)
{
if
(
dist
[
n_y
*
k
+
e1
]
>
tmp_dist
)
{
if
(
best_
dist
[
e1
]
>
tmp_dist
)
{
for
(
int64_t
e2
=
k
-
1
;
e2
>
e1
;
e2
--
)
{
for
(
int64_t
e2
=
k
-
1
;
e2
>
e1
;
e2
--
)
{
dist
[
n_y
*
k
+
e2
]
=
dist
[
n_y
*
k
+
e2
-
1
];
best_dist
[
e2
]
=
best_dist
[
e2
-
1
];
col
[
n_y
*
k
+
e2
]
=
col
[
n_y
*
k
+
e2
-
1
];
best_idx
[
e2
]
=
best_idx
[
e2
-
1
];
}
}
dist
[
n_y
*
k
+
e1
]
=
tmp_dist
;
best_
dist
[
e1
]
=
tmp_dist
;
col
[
n_y
*
k
+
e1
]
=
n_x
;
best_idx
[
e1
]
=
n_x
;
break
;
break
;
}
}
}
}
}
}
for
(
int64_t
e
=
0
;
e
<
k
;
e
++
)
{
row
[
n_y
*
k
+
e
]
=
n_y
;
col
[
n_y
*
k
+
e
]
=
best_idx
[
e
];
}
}
}
torch
::
Tensor
knn_cuda
(
const
torch
::
Tensor
x
,
const
torch
::
Tensor
y
,
torch
::
Tensor
knn_cuda
(
const
torch
::
Tensor
x
,
const
torch
::
Tensor
y
,
...
@@ -89,10 +89,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
...
@@ -89,10 +89,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
const
bool
cosine
)
{
const
bool
cosine
)
{
CHECK_CUDA
(
x
);
CHECK_CUDA
(
x
);
CHECK_CONTIGUOUS
(
x
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_CUDA
(
y
);
CHECK_CUDA
(
y
);
CHECK_CONTIGUOUS
(
y
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
CHECK_INPUT
(
x
.
size
(
1
)
==
y
.
size
(
1
));
CHECK_INPUT
(
x
.
size
(
1
)
==
y
.
size
(
1
));
AT_ASSERTM
(
k
<=
100
,
"`k` needs to smaller than or equal to 100"
);
if
(
ptr_x
.
has_value
())
{
if
(
ptr_x
.
has_value
())
{
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_CUDA
(
ptr_x
.
value
());
...
@@ -112,7 +115,6 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
...
@@ -112,7 +115,6 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
cudaSetDevice
(
x
.
get_device
());
cudaSetDevice
(
x
.
get_device
());
auto
dist
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
1e10
,
y
.
options
());
auto
row
=
torch
::
empty
(
y
.
size
(
0
)
*
k
,
ptr_y
.
value
().
options
());
auto
row
=
torch
::
empty
(
y
.
size
(
0
)
*
k
,
ptr_y
.
value
().
options
());
auto
col
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
-
1
,
ptr_y
.
value
().
options
());
auto
col
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
-
1
,
ptr_y
.
value
().
options
());
...
@@ -123,9 +125,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
...
@@ -123,9 +125,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
knn_kernel
<
scalar_t
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
knn_kernel
<
scalar_t
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
x
.
data_ptr
<
scalar_t
>
(),
y
.
data_ptr
<
scalar_t
>
(),
x
.
data_ptr
<
scalar_t
>
(),
y
.
data_ptr
<
scalar_t
>
(),
ptr_x
.
value
().
data_ptr
<
int64_t
>
(),
ptr_y
.
value
().
data_ptr
<
int64_t
>
(),
ptr_x
.
value
().
data_ptr
<
int64_t
>
(),
ptr_y
.
value
().
data_ptr
<
int64_t
>
(),
dist
.
data_ptr
<
scalar_t
>
(),
row
.
data_ptr
<
int64_t
>
(),
row
.
data_ptr
<
int64_t
>
(),
col
.
data_ptr
<
int64_t
>
(),
k
,
x
.
size
(
0
),
col
.
data_ptr
<
int64_t
>
(),
k
,
x
.
size
(
0
),
y
.
size
(
0
),
x
.
size
(
1
),
y
.
size
(
0
),
x
.
size
(
1
),
ptr_x
.
value
().
numel
()
-
1
,
cosine
);
ptr_x
.
value
().
numel
()
-
1
,
cosine
);
});
});
auto
mask
=
col
!=
-
1
;
auto
mask
=
col
!=
-
1
;
...
...
csrc/cuda/radius_cuda.cu
View file @
0d735d7e
...
@@ -4,84 +4,88 @@
...
@@ -4,84 +4,88 @@
#include "utils.cuh"
#include "utils.cuh"
#define THREADS
1024
#define THREADS
256
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
radius_kernel
(
const
scalar_t
*
x
,
const
scalar_t
*
y
,
__global__
void
const
int64_t
*
ptr_x
,
const
int64_t
*
ptr_y
,
radius_kernel
(
const
scalar_t
*
__restrict__
x
,
const
scalar_t
*
__restrict__
y
,
int64_t
*
row
,
int64_t
*
col
,
scalar_t
radius
,
const
int64_t
*
__restrict__
ptr_x
,
int64_t
max_num_neighbors
,
int64_t
dim
)
{
const
int64_t
*
__restrict__
ptr_y
,
int64_t
*
__restrict__
row
,
int64_t
*
__restrict__
col
,
const
scalar_t
r
,
const
int64_t
n
,
const
int64_t
batch_idx
=
blockIdx
.
x
;
const
int64_t
m
,
const
int64_t
dim
,
const
int64_t
num_examples
,
const
int64_t
max_num_neighbors
)
{
const
int64_t
x_start_idx
=
ptr_x
[
batch_idx
];
const
int64_t
x_end_idx
=
ptr_x
[
batch_idx
+
1
];
const
int64_t
n_y
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
n_y
>=
m
)
const
int64_t
y_start_idx
=
ptr_y
[
batch_idx
];
return
;
const
int64_t
y_end_idx
=
ptr_y
[
batch_idx
+
1
];
int64_t
count
=
0
;
for
(
int64_t
n_y
=
y_start_idx
+
threadIdx
.
x
;
n_y
<
y_end_idx
;
const
int64_t
example_idx
=
get_example_idx
(
n_y
,
ptr_y
,
num_examples
);
n_y
+=
THREADS
)
{
int64_t
count
=
0
;
for
(
int64_t
n_x
=
ptr_x
[
example_idx
];
n_x
<
ptr_x
[
example_idx
+
1
];
n_x
++
)
{
for
(
int64_t
n_x
=
x_start_idx
;
n_x
<
x_end_idx
;
n_x
++
)
{
scalar_t
dist
=
0
;
scalar_t
dist
=
0
;
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
}
dist
=
sqrt
(
dist
);
if
(
dist
<
radius
)
{
row
[
n_y
*
max_num_neighbors
+
count
]
=
n_y
;
col
[
n_y
*
max_num_neighbors
+
count
]
=
n_x
;
count
++
;
}
if
(
count
>=
max_num_neighbors
)
{
break
;
}
}
}
if
(
dist
<
r
)
{
row
[
n_y
*
max_num_neighbors
+
count
]
=
n_y
;
col
[
n_y
*
max_num_neighbors
+
count
]
=
n_x
;
count
++
;
}
if
(
count
>=
max_num_neighbors
)
break
;
}
}
}
}
torch
::
Tensor
radius_cuda
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
Tensor
radius_cuda
(
const
torch
::
Tensor
x
,
const
torch
::
Tensor
y
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
double
r
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
const
double
r
,
int64_t
max_num_neighbors
)
{
const
int64_t
max_num_neighbors
)
{
CHECK_CUDA
(
x
);
CHECK_CUDA
(
x
);
CHECK_CONTIGUOUS
(
x
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_CUDA
(
y
);
CHECK_CUDA
(
y
);
CHECK_CONTIGUOUS
(
y
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
CHECK_INPUT
(
x
.
size
(
1
)
==
y
.
size
(
1
));
cudaSetDevice
(
x
.
get_device
());
cudaSetDevice
(
x
.
get_device
());
if
(
ptr_x
.
has_value
())
{
if
(
ptr_x
.
has_value
())
{
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_CUDA
(
ptr_x
.
value
());
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
}
else
{
}
else
ptr_x
=
torch
::
arange
(
0
,
x
.
size
(
0
)
+
1
,
x
.
size
(
0
),
ptr_x
=
torch
::
arange
(
0
,
x
.
size
(
0
)
+
1
,
x
.
size
(
0
),
x
.
options
().
dtype
(
torch
::
kLong
));
x
.
options
().
dtype
(
torch
::
kLong
));
}
if
(
ptr_y
.
has_value
())
{
if
(
ptr_y
.
has_value
())
{
CHECK_CUDA
(
ptr_y
.
value
());
CHECK_CUDA
(
ptr_y
.
value
());
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
else
{
}
else
ptr_y
=
torch
::
arange
(
0
,
y
.
size
(
0
)
+
1
,
y
.
size
(
0
),
ptr_y
=
torch
::
arange
(
0
,
y
.
size
(
0
)
+
1
,
y
.
size
(
0
),
y
.
options
().
dtype
(
torch
::
kLong
));
y
.
options
().
dtype
(
torch
::
kLong
));
}
CHECK_INPUT
(
ptr_x
.
value
().
numel
()
==
ptr_y
.
value
().
numel
());
CHECK_INPUT
(
ptr_x
.
value
().
numel
()
==
ptr_y
.
value
().
numel
());
cudaSetDevice
(
x
.
get_device
());
auto
row
=
auto
row
=
torch
::
full
(
y
.
size
(
0
)
*
max_num_neighbors
,
-
1
,
ptr_y
.
value
().
options
());
torch
::
full
(
y
.
size
(
0
)
*
max_num_neighbors
,
-
1
,
ptr_y
.
value
().
options
());
auto
col
=
auto
col
=
torch
::
full
(
y
.
size
(
0
)
*
max_num_neighbors
,
-
1
,
ptr_y
.
value
().
options
());
torch
::
full
(
y
.
size
(
0
)
*
max_num_neighbors
,
-
1
,
ptr_y
.
value
().
options
());
dim3
BLOCKS
((
y
.
size
(
0
)
+
THREADS
-
1
)
/
THREADS
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"radius_kernel"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"radius_kernel"
,
[
&
]
{
radius_kernel
<
scalar_t
><<<
ptr_x
.
value
().
size
(
0
)
-
1
,
THREADS
,
0
,
stream
>>>
(
radius_kernel
<
scalar_t
><<<
BLOCKS
,
THREADS
,
0
,
stream
>>>
(
x
.
data_ptr
<
scalar_t
>
(),
y
.
data_ptr
<
scalar_t
>
(),
x
.
data_ptr
<
scalar_t
>
(),
y
.
data_ptr
<
scalar_t
>
(),
ptr_x
.
value
().
data_ptr
<
int64_t
>
(),
ptr_y
.
value
().
data_ptr
<
int64_t
>
(),
ptr_x
.
value
().
data_ptr
<
int64_t
>
(),
ptr_y
.
value
().
data_ptr
<
int64_t
>
(),
row
.
data_ptr
<
int64_t
>
(),
col
.
data_ptr
<
int64_t
>
(),
r
,
max_num_neighbors
,
row
.
data_ptr
<
int64_t
>
(),
col
.
data_ptr
<
int64_t
>
(),
r
*
r
,
x
.
size
(
0
)
,
x
.
size
(
1
)
);
y
.
size
(
0
),
x
.
size
(
1
),
ptr_x
.
value
().
numel
()
-
1
,
max_num_neighbors
);
});
});
auto
mask
=
row
!=
-
1
;
auto
mask
=
row
!=
-
1
;
...
...
csrc/cuda/utils.cuh
View file @
0d735d7e
...
@@ -5,3 +5,14 @@
...
@@ -5,3 +5,14 @@
#define CHECK_CUDA(x) \
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
__device__
int64_t
get_example_idx
(
int64_t
idx
,
const
int64_t
*
ptr
,
const
int64_t
num_examples
)
{
for
(
int64_t
i
=
0
;
i
<
num_examples
;
i
++
)
{
if
(
ptr
[
i
+
1
]
>
idx
)
return
i
;
}
return
num_examples
-
1
;
}
test/test_radius.py
View file @
0d735d7e
...
@@ -71,7 +71,7 @@ def test_radius_graph_large(dtype, device):
...
@@ -71,7 +71,7 @@ def test_radius_graph_large(dtype, device):
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
x
=
torch
.
randn
(
1000
,
3
,
dtype
=
dtype
,
device
=
device
)
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
edge_index
=
radius_graph
(
x
,
r
=
0.5
,
flow
=
'target_to_source'
,
loop
=
True
,
max_num_neighbors
=
2000
,
num_workers
=
6
)
max_num_neighbors
=
2000
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
cpu
().
numpy
())
col
=
tree
.
query_ball_point
(
x
.
cpu
(),
r
=
0.5
)
col
=
tree
.
query_ball_point
(
x
.
cpu
(),
r
=
0.5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment