".github/vscode:/vscode.git/clone" did not exist on "493f9529d75b868da58bb16f711ed39c73f39b8c"
Unverified Commit 27b008b9 authored by Tianqi Zhang (张天启)'s avatar Tianqi Zhang (张天启) Committed by GitHub
Browse files

[BugFix] Fix unstable behavior of bruteforce-sharemem KNN (#5515)

parent c51cc82e
......@@ -12,6 +12,7 @@
#include <algorithm>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
#include "../../../array/cuda/dgl_cub.cuh"
......@@ -22,6 +23,20 @@
namespace dgl {
namespace transform {
namespace impl {
/**
 * @brief Round `size` up to the next multiple of `align`.
 *
 * The rounding is done with a bit mask, so the result is only a true
 * multiple of `align` when `align` is a power of two.
 *
 * `size` is returned unchanged when it is 0, or when `align` is 0 or 1.
 * The function is only available for unsigned integer types.
 *
 * e.g. Pow2Align(17, 4) = 20, Pow2Align(17, 8) = 24, Pow2Align(16, 8) = 16
 *
 * @param size Value to round up.
 * @param align Alignment to round to (expected to be a power of two).
 * @return `size` rounded up as described above.
 */
template <typename Type>
static __host__ __device__ std::enable_if_t<std::is_unsigned<Type>::value, Type>
Pow2Align(Type size, Type align) {
  // Nothing to round: zero stays zero, and alignments below 2 are no-ops.
  const bool nothing_to_do = (size == 0) || (align < 2);
  if (nothing_to_do) return size;
  // Setting every bit below the alignment on (size - 1) and adding one
  // lands on the next boundary; already-aligned sizes map to themselves.
  const Type low_bits = align - 1;
  return ((size - 1) | low_bits) + 1;
}
/**
* @brief Utility class used to avoid linker errors with extern
* unsized shared memory arrays with templated type
......@@ -307,15 +322,19 @@ __global__ void BruteforceKnnShareKernel(
FloatType* data_buff = SharedMemory<FloatType>();
FloatType* query_buff = data_buff + block_size * feature_size;
FloatType* dist_buff = query_buff + block_size * feature_size;
IdType* res_buff = reinterpret_cast<IdType*>(dist_buff + block_size * k);
IdType* res_buff = reinterpret_cast<IdType*>(Pow2Align<uint64_t>(
reinterpret_cast<uint64_t>(dist_buff + block_size * k), sizeof(IdType)));
FloatType worst_dist = std::numeric_limits<FloatType>::max();
// initialize dist buff with inf value
for (auto i = 0; i < k; ++i) {
dist_buff[threadIdx.x * k + i] = std::numeric_limits<FloatType>::max();
dist_buff[threadIdx.x + i * block_size] =
std::numeric_limits<FloatType>::max();
}
// load query data to shared memory
// TODO(tianqi): this load could be restructured to exploit coalesced global
// memory access.
if (query_idx < query_end) {
for (auto i = 0; i < feature_size; ++i) {
// to avoid bank conflict, we use transpose here
......@@ -388,6 +407,7 @@ __global__ void BruteforceKnnShareKernel(
worst_dist = dist_buff[threadIdx.x * k];
}
}
__syncthreads();
}
// copy result to global memory
......@@ -503,6 +523,7 @@ void BruteForceKNNSharedCuda(
const FloatType* query_points_data = query_points.Ptr<FloatType>();
IdType* query_out = result.Ptr<IdType>();
IdType* data_out = query_out + k * query_points->shape[0];
constexpr size_t smem_align = std::max(sizeof(IdType), sizeof(FloatType));
// get max shared memory per block in bytes
// determine block size according to this value
......@@ -510,8 +531,10 @@ void BruteForceKNNSharedCuda(
CUDA_CALL(cudaDeviceGetAttribute(
&max_sharedmem_per_block, cudaDevAttrMaxSharedMemoryPerBlock,
ctx.device_id));
const int64_t single_shared_mem =
(k + 2 * feature_size) * sizeof(FloatType) + k * sizeof(IdType);
const int64_t single_shared_mem = static_cast<int64_t>(Pow2Align<size_t>(
(k + 2 * feature_size) * sizeof(FloatType) + k * sizeof(IdType),
smem_align));
const int64_t block_size =
cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem);
......@@ -538,6 +561,9 @@ void BruteForceKNNSharedCuda(
batch_size, stream));
device->FreeWorkspace(ctx, prefix_temp);
// wait for results
CUDA_CALL(cudaStreamSynchronize(stream));
int64_t num_blocks = 0, final_elem = 0,
copyoffset = (batch_size - 1) * sizeof(IdType);
device->CopyDataFromTo(
......@@ -548,7 +574,6 @@ void BruteForceKNNSharedCuda(
DGLContext{kDGLCPU, 0}, query_offsets->dtype);
num_blocks += final_elem;
device->FreeWorkspace(ctx, num_block_per_segment);
device->FreeWorkspace(ctx, num_block_prefixsum);
// get batch id and local id in segment
temp_block_size = cuda::FindNumThreads(num_blocks);
......@@ -570,6 +595,7 @@ void BruteForceKNNSharedCuda(
data_offsets_data, query_points_data, query_offsets_data, block_batch_id,
local_block_id, k, dists, query_out, data_out, batch_size, feature_size);
device->FreeWorkspace(ctx, num_block_prefixsum);
device->FreeWorkspace(ctx, dists);
device->FreeWorkspace(ctx, local_block_id);
device->FreeWorkspace(ctx, block_batch_id);
......
......@@ -183,6 +183,33 @@ def test_knn_cuda(algorithm, dist, exclude_self):
_test_knn_common(F.cuda(), algorithm, dist, exclude_self)
@pytest.mark.parametrize("num_points", [8, 64, 256, 1024])
def test_knn_sharedmem_large(num_points):
    """Compare "bruteforce-sharemem" KNN indices with a dense torch reference.

    Two random point sets of ``num_points`` 5-dimensional points are drawn on
    the GPU. For every query point, the sorted indices of its 4 nearest data
    points returned by ``dgl.functional.knn`` must equal the ones obtained from
    a full pairwise squared-distance matrix. The test returns without checking
    anything when CUDA is not available.
    """
    if not th.cuda.is_available():
        return

    feat_dim = 5
    num_neighbors = 4
    data_pts = th.randn(num_points, feat_dim, device="cuda")
    query_pts = th.randn(num_points, feat_dim, device="cuda")

    def reference_neighbors(data, query, topk):
        # Squared Euclidean distance expanded as |d|^2 + |q|^2 - 2 * q.d,
        # giving a (num_query, num_data) matrix.
        data_sq = th.sum(data * data, dim=1)
        query_sq = th.sum(query * query, dim=1).unsqueeze(-1)
        pairwise = data_sq + query_sq - 2 * th.mm(query, data.T)
        _, nearest = th.topk(pairwise, topk, dim=-1, largest=False)
        # Sort the indices so the comparison ignores neighbor ordering.
        ordered, _ = th.sort(nearest, dim=-1)
        return ordered

    expected = reference_neighbors(data_pts, query_pts, num_neighbors)

    knn_out = dgl.functional.knn(
        num_neighbors,
        data_pts,
        [num_points],
        query_pts,
        [num_points],
        algorithm="bruteforce-sharemem",
    )
    # Row 1 of the output is compared against the reference data-point indices.
    found, _ = th.sort(knn_out[1].reshape(-1, num_neighbors), -1)

    assert th.all(found == expected).item()
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
@pytest.mark.parametrize("weight", [True, False])
......@@ -224,3 +251,4 @@ if __name__ == "__main__":
test_fps()
test_fps_start_idx()
test_knn()
test_knn_sharedmem_large()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment