Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
27b008b9
"src/libtorchaudio/sox/pybind/pybind.cpp" did not exist on "0076ab073d1ee6160efbc239e075196b35ed850b"
Unverified
Commit
27b008b9
authored
Apr 10, 2023
by
Tianqi Zhang (张天启)
Committed by
GitHub
Apr 10, 2023
Browse files
[BugFix] Fix unstable behavior of bruteforce-sharemem KNN (#5515)
parent
c51cc82e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
5 deletions
+59
-5
src/graph/transform/cuda/knn.cu
src/graph/transform/cuda/knn.cu
+31
-5
tests/python/pytorch/geometry/test_geometry.py
tests/python/pytorch/geometry/test_geometry.py
+28
-0
No files found.
src/graph/transform/cuda/knn.cu
View file @
27b008b9
...
...
@@ -12,6 +12,7 @@
#include <algorithm>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
#include "../../../array/cuda/dgl_cub.cuh"
...
...
@@ -22,6 +23,20 @@
namespace
dgl
{
namespace
transform
{
namespace
impl
{
/**
 * @brief Round `size` up to the nearest multiple of `align`.
 *
 * Returns the smallest value greater than or equal to `size` that is a
 * multiple of `align`. A power-of-two `align` is handled with a bit mask;
 * any other alignment falls back to integer division, so the result is a
 * true multiple of `align` in both cases.
 *
 * e.g. Pow2Align(17, 4) = 20, Pow2Align(17, 8) = 24, Pow2Align(17, 6) = 18
 *
 * @tparam Type Unsigned integer type. The function is removed from overload
 *     resolution (via std::enable_if_t) for any other type.
 * @param size Value to round up. A `size` of 0 is returned unchanged.
 * @param align Requested alignment. An `align` of 0 or 1 leaves `size`
 *     unchanged.
 *
 * @return `size` rounded up to a multiple of `align`.
 *
 * @note No overflow check is performed: if the rounded value does not fit in
 *     `Type`, the unsigned arithmetic wraps around.
 */
template <typename Type>
static __host__ __device__
    std::enable_if_t<std::is_unsigned<Type>::value, Type>
    Pow2Align(Type size, Type align) {
  // `Type` is unsigned, so "non-positive" can only mean zero.
  if (align <= 1 || size == 0) return size;

  const Type mask = align - 1;
  if ((align & mask) == 0) {
    // Power-of-two alignment: fill the low bits of (size - 1), then step to
    // the next multiple.
    return ((size - 1) | mask) + 1;
  }

  // Any other alignment: the mask trick would not yield a multiple of `align`
  // (e.g. size = 17, align = 6 would give 22), so round up by division.
  return ((size - 1) / align + 1) * align;
}
/**
* @brief Utility class used to avoid linker errors with extern
* unsized shared memory arrays with templated type
...
...
@@ -307,15 +322,19 @@ __global__ void BruteforceKnnShareKernel(
FloatType
*
data_buff
=
SharedMemory
<
FloatType
>
();
FloatType
*
query_buff
=
data_buff
+
block_size
*
feature_size
;
FloatType
*
dist_buff
=
query_buff
+
block_size
*
feature_size
;
IdType
*
res_buff
=
reinterpret_cast
<
IdType
*>
(
dist_buff
+
block_size
*
k
);
IdType
*
res_buff
=
reinterpret_cast
<
IdType
*>
(
Pow2Align
<
uint64_t
>
(
reinterpret_cast
<
uint64_t
>
(
dist_buff
+
block_size
*
k
),
sizeof
(
IdType
)));
FloatType
worst_dist
=
std
::
numeric_limits
<
FloatType
>::
max
();
// initialize dist buff with inf value
for
(
auto
i
=
0
;
i
<
k
;
++
i
)
{
dist_buff
[
threadIdx
.
x
*
k
+
i
]
=
std
::
numeric_limits
<
FloatType
>::
max
();
dist_buff
[
threadIdx
.
x
+
i
*
block_size
]
=
std
::
numeric_limits
<
FloatType
>::
max
();
}
// load query data to shared memory
// TODO(tianqi): could be better here to exploit coalesce global memory
// access.
if
(
query_idx
<
query_end
)
{
for
(
auto
i
=
0
;
i
<
feature_size
;
++
i
)
{
// to avoid bank conflict, we use transpose here
...
...
@@ -388,6 +407,7 @@ __global__ void BruteforceKnnShareKernel(
worst_dist
=
dist_buff
[
threadIdx
.
x
*
k
];
}
}
__syncthreads
();
}
// copy result to global memory
...
...
@@ -503,6 +523,7 @@ void BruteForceKNNSharedCuda(
const
FloatType
*
query_points_data
=
query_points
.
Ptr
<
FloatType
>
();
IdType
*
query_out
=
result
.
Ptr
<
IdType
>
();
IdType
*
data_out
=
query_out
+
k
*
query_points
->
shape
[
0
];
constexpr
size_t
smem_align
=
std
::
max
(
sizeof
(
IdType
),
sizeof
(
FloatType
));
// get max shared memory per block in bytes
// determine block size according to this value
...
...
@@ -510,8 +531,10 @@ void BruteForceKNNSharedCuda(
CUDA_CALL
(
cudaDeviceGetAttribute
(
&
max_sharedmem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlock
,
ctx
.
device_id
));
const
int64_t
single_shared_mem
=
(
k
+
2
*
feature_size
)
*
sizeof
(
FloatType
)
+
k
*
sizeof
(
IdType
);
const
int64_t
single_shared_mem
=
static_cast
<
int64_t
>
(
Pow2Align
<
size_t
>
(
(
k
+
2
*
feature_size
)
*
sizeof
(
FloatType
)
+
k
*
sizeof
(
IdType
),
smem_align
));
const
int64_t
block_size
=
cuda
::
FindNumThreads
(
max_sharedmem_per_block
/
single_shared_mem
);
...
...
@@ -538,6 +561,9 @@ void BruteForceKNNSharedCuda(
batch_size
,
stream
));
device
->
FreeWorkspace
(
ctx
,
prefix_temp
);
// wait for results
CUDA_CALL
(
cudaStreamSynchronize
(
stream
));
int64_t
num_blocks
=
0
,
final_elem
=
0
,
copyoffset
=
(
batch_size
-
1
)
*
sizeof
(
IdType
);
device
->
CopyDataFromTo
(
...
...
@@ -548,7 +574,6 @@ void BruteForceKNNSharedCuda(
DGLContext
{
kDGLCPU
,
0
},
query_offsets
->
dtype
);
num_blocks
+=
final_elem
;
device
->
FreeWorkspace
(
ctx
,
num_block_per_segment
);
device
->
FreeWorkspace
(
ctx
,
num_block_prefixsum
);
// get batch id and local id in segment
temp_block_size
=
cuda
::
FindNumThreads
(
num_blocks
);
...
...
@@ -570,6 +595,7 @@ void BruteForceKNNSharedCuda(
data_offsets_data
,
query_points_data
,
query_offsets_data
,
block_batch_id
,
local_block_id
,
k
,
dists
,
query_out
,
data_out
,
batch_size
,
feature_size
);
device
->
FreeWorkspace
(
ctx
,
num_block_prefixsum
);
device
->
FreeWorkspace
(
ctx
,
dists
);
device
->
FreeWorkspace
(
ctx
,
local_block_id
);
device
->
FreeWorkspace
(
ctx
,
block_batch_id
);
...
...
tests/python/pytorch/geometry/test_geometry.py
View file @
27b008b9
...
...
@@ -183,6 +183,33 @@ def test_knn_cuda(algorithm, dist, exclude_self):
_test_knn_common
(
F
.
cuda
(),
algorithm
,
dist
,
exclude_self
)
@pytest.mark.parametrize("num_points", [8, 64, 256, 1024])
def test_knn_sharedmem_large(num_points):
    """Compare "bruteforce-sharemem" KNN with a dense PyTorch reference.

    Two random point sets of ``num_points`` 5-dimensional points are created
    on the GPU. For every point of ``y`` the ``k = 4`` nearest points of ``x``
    are computed twice: once from a full pairwise squared-distance matrix and
    once through ``dgl.functional.knn``. Both index sets are sorted per query
    so the comparison does not depend on the order in which neighbours are
    reported.

    The test is skipped when no CUDA device is available, instead of
    returning silently and being reported as passed.
    """
    if not th.cuda.is_available():
        pytest.skip("bruteforce-sharemem KNN requires a CUDA device")

    x = th.randn(num_points, 5, device="cuda")
    y = th.randn(num_points, 5, device="cuda")
    k = 4

    def ground_truth(x, y, k):
        # dist[i, j] = |x_j|^2 + |y_i|^2 - 2 * <y_i, x_j>, i.e. the squared
        # Euclidean distance between query y_i and data point x_j.
        dist = (
            th.sum(x * x, dim=1)
            + th.sum(y * y, dim=1).unsqueeze(-1)
            - 2 * th.mm(y, x.T)
        )
        # Indices of the k smallest distances per row, sorted by index value.
        ret = th.topk(dist, k, dim=-1, largest=False)[1]
        return th.sort(ret, dim=-1)[0]

    gt = ground_truth(x, y, k)
    # NOTE(review): element [1] of the knn result is presumably the row of
    # neighbour indices into `x`, laid out as k consecutive entries per
    # query -- confirm against the dgl.functional.knn documentation.
    actual = th.sort(
        dgl.functional.knn(
            k, x, [num_points], y, [num_points], algorithm="bruteforce-sharemem"
        )[1].reshape(-1, k),
        -1,
    )[0]
    assert th.all(actual == gt).item()
@
parametrize_idtype
@
pytest
.
mark
.
parametrize
(
"g"
,
get_cases
([
"homo"
],
exclude
=
[
"dglgraph"
]))
@
pytest
.
mark
.
parametrize
(
"weight"
,
[
True
,
False
])
...
...
@@ -224,3 +251,4 @@ if __name__ == "__main__":
test_fps
()
test_fps_start_idx
()
test_knn
()
test_knn_sharedmem_large
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment