#include "hip/hip_runtime.h" /*! * Copyright (c) 2020 by Contributors * \file graph/transform/cuda/knn.cu * \brief k-nearest-neighbor (KNN) implementation (cuda) */ #include #include #include #include #include #include #include #include #include "../../../array/cuda/dgl_cub.cuh" #include "../../../runtime/cuda/cuda_common.h" #include "../../../array/cuda/utils.h" #include "../knn.h" namespace dgl { namespace transform { namespace impl { /*! * \brief Utility class used to avoid linker errors with extern * unsized shared memory arrays with templated type */ template struct SharedMemory { __device__ inline operator Type* () { extern __shared__ int __smem[]; return reinterpret_cast(__smem); } __device__ inline operator const Type* () const { extern __shared__ int __smem[]; return reinterpret_cast(__smem); } }; // specialize for double to avoid unaligned memory // access compile errors template <> struct SharedMemory { __device__ inline operator double* () { extern __shared__ double __smem_d[]; return reinterpret_cast(__smem_d); } __device__ inline operator const double* () const { extern __shared__ double __smem_d[]; return reinterpret_cast(__smem_d); } }; /*! \brief Compute Euclidean distance between two vectors in a cuda kernel */ template __device__ FloatType EuclideanDist(const FloatType* vec1, const FloatType* vec2, const int64_t dim) { FloatType dist = 0; IdType idx = 0; for (; idx < dim - 3; idx += 4) { FloatType diff0 = vec1[idx] - vec2[idx]; FloatType diff1 = vec1[idx + 1] - vec2[idx + 1]; FloatType diff2 = vec1[idx + 2] - vec2[idx + 2]; FloatType diff3 = vec1[idx + 3] - vec2[idx + 3]; dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; } for (; idx < dim; ++idx) { FloatType diff = vec1[idx] - vec2[idx]; dist += diff * diff; } return dist; } /*! * \brief Compute Euclidean distance between two vectors in a cuda kernel, * return positive infinite value if the intermediate distance is greater * than the worst distance. 
/*! \brief Compute Euclidean distance between two vectors in a cuda kernel */
template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDist(const FloatType* vec1, const FloatType* vec2,
                                   const int64_t dim) {
  FloatType dist = 0;
  IdType idx = 0;
  // process 4 elements per iteration to expose instruction-level parallelism
  for (; idx < dim - 3; idx += 4) {
    FloatType diff0 = vec1[idx] - vec2[idx];
    FloatType diff1 = vec1[idx + 1] - vec2[idx + 1];
    FloatType diff2 = vec1[idx + 2] - vec2[idx + 2];
    FloatType diff3 = vec1[idx + 3] - vec2[idx + 3];
    dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
  }

  for (; idx < dim; ++idx) {
    FloatType diff = vec1[idx] - vec2[idx];
    dist += diff * diff;
  }

  return dist;
}

/*!
 * \brief Compute Euclidean distance between two vectors in a cuda kernel,
 *  return positive infinite value if the intermediate distance is greater
 *  than the worst distance.
 */
template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDistWithCheck(const FloatType* vec1, const FloatType* vec2,
                                            const int64_t dim, const FloatType worst_dist) {
  FloatType dist = 0;
  IdType idx = 0;
  bool early_stop = false;

  for (; idx < dim - 3; idx += 4) {
    FloatType diff0 = vec1[idx] - vec2[idx];
    FloatType diff1 = vec1[idx + 1] - vec2[idx + 1];
    FloatType diff2 = vec1[idx + 2] - vec2[idx + 2];
    FloatType diff3 = vec1[idx + 3] - vec2[idx + 3];
    dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;

    if (dist > worst_dist) {
      early_stop = true;
      idx = dim;
      break;
    }
  }

  for (; idx < dim; ++idx) {
    FloatType diff = vec1[idx] - vec2[idx];
    dist += diff * diff;
    if (dist > worst_dist) {
      early_stop = true;
      break;
    }
  }

  if (early_stop) {
    return std::numeric_limits<FloatType>::max();
  } else {
    return dist;
  }
}

template <typename IdType, typename FloatType>
__device__ void BuildHeap(IdType* indices, FloatType* dists, int size) {
  for (int i = size / 2 - 1; i >= 0; --i) {
    IdType idx = i;
    while (true) {
      IdType largest = idx;
      IdType left = idx * 2 + 1;
      IdType right = left + 1;
      if (left < size && dists[left] > dists[largest]) {
        largest = left;
      }
      if (right < size && dists[right] > dists[largest]) {
        largest = right;
      }
      if (largest != idx) {
        IdType tmp_idx = indices[largest];
        indices[largest] = indices[idx];
        indices[idx] = tmp_idx;

        FloatType tmp_dist = dists[largest];
        dists[largest] = dists[idx];
        dists[idx] = tmp_dist;
        idx = largest;
      } else {
        break;
      }
    }
  }
}

template <typename IdType, typename FloatType>
__device__ void HeapInsert(IdType* indices, FloatType* dist, IdType new_idx,
                           FloatType new_dist, int size, bool check_repeat = false) {
  if (new_dist > dist[0]) return;

  // check if we have it
  if (check_repeat) {
    for (IdType i = 0; i < size; ++i) {
      if (indices[i] == new_idx) return;
    }
  }

  IdType left = 0, right = 0, idx = 0, largest = 0;
  dist[0] = new_dist;
  indices[0] = new_idx;
  while (true) {
    left = idx * 2 + 1;
    right = left + 1;
    if (left < size && dist[left] > dist[largest]) {
      largest = left;
    }
    if (right < size && dist[right] > dist[largest]) {
      largest = right;
    }
    if (largest != idx) {
      IdType tmp_idx = indices[idx];
      indices[idx] = indices[largest];
      indices[largest] = tmp_idx;

      FloatType tmp_dist = dist[idx];
      dist[idx] = dist[largest];
      dist[largest] = tmp_dist;
      idx = largest;
    } else {
      break;
    }
  }
}

template <typename IdType, typename FloatType>
__device__ bool FlaggedHeapInsert(IdType* indices, FloatType* dist, bool* flags,
                                  IdType new_idx, FloatType new_dist, bool new_flag,
                                  int size, bool check_repeat = false) {
  if (new_dist > dist[0]) return false;

  // check if we have it
  if (check_repeat) {
    for (IdType i = 0; i < size; ++i) {
      if (indices[i] == new_idx) return false;
    }
  }

  IdType left = 0, right = 0, idx = 0, largest = 0;
  dist[0] = new_dist;
  indices[0] = new_idx;
  flags[0] = new_flag;
  while (true) {
    left = idx * 2 + 1;
    right = left + 1;
    if (left < size && dist[left] > dist[largest]) {
      largest = left;
    }
    if (right < size && dist[right] > dist[largest]) {
      largest = right;
    }
    if (largest != idx) {
      IdType tmp_idx = indices[idx];
      indices[idx] = indices[largest];
      indices[largest] = tmp_idx;

      FloatType tmp_dist = dist[idx];
      dist[idx] = dist[largest];
      dist[largest] = tmp_dist;

      bool tmp_flag = flags[idx];
      flags[idx] = flags[largest];
      flags[largest] = tmp_flag;
      idx = largest;
    } else {
      break;
    }
  }
  return true;
}
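
// Note (added for clarity, not part of the upstream file): the three helpers
// above maintain a fixed-size *max*-heap over the k best distances found so
// far, so dist[0] always holds the worst retained distance and a candidate is
// accepted only if it beats dist[0]. For example, with k = 3 and heap dists
// {9, 4, 7}, inserting 5 replaces the root and re-heapifies to {7, 4, 5}.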
/*!
 * \brief Brute force kNN kernel. Compute distance for each pair of input points and get
 *  the result directly (without a distance matrix).
 */
template <typename FloatType, typename IdType>
__global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType* data_offsets,
                                    const FloatType* query_points, const IdType* query_offsets,
                                    const int k, FloatType* dists, IdType* query_out,
                                    IdType* data_out, const int64_t num_batches,
                                    const int64_t feature_size) {
  const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (q_idx >= query_offsets[num_batches]) return;
  // find the batch segment this query point belongs to
  IdType batch_idx = 0;
  for (IdType b = 0; b < num_batches + 1; ++b) {
    if (query_offsets[b] > q_idx) {
      batch_idx = b - 1;
      break;
    }
  }
  const IdType data_start = data_offsets[batch_idx], data_end = data_offsets[batch_idx + 1];

  for (IdType k_idx = 0; k_idx < k; ++k_idx) {
    query_out[q_idx * k + k_idx] = q_idx;
    dists[q_idx * k + k_idx] = std::numeric_limits<FloatType>::max();
  }
  FloatType worst_dist = std::numeric_limits<FloatType>::max();

  for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) {
    FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(
      query_points + q_idx * feature_size, data_points + d_idx * feature_size,
      feature_size, worst_dist);

    IdType out_offset = q_idx * k;
    HeapInsert(data_out + out_offset, dists + out_offset, d_idx, tmp_dist, k);
    worst_dist = dists[q_idx * k];
  }
}
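
// Worked example (added for clarity, not part of the upstream file): with two
// batches and query_offsets = {0, 5, 9}, query point q_idx = 6 scans until
// query_offsets[b] > 6 (which happens at b = 2) and lands in batch_idx = 1,
// so it is only compared against data points in
// data_offsets[1] .. data_offsets[2].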
/*!
 * \brief Same as BruteforceKnnKernel, but use shared memory as buffer.
 *  This kernel divides query points and data points into blocks. For each
 *  query block, it will make a loop over all data blocks and compute distances.
 *  This kernel is faster when the dimension of input points is not large.
 */
template <typename FloatType, typename IdType>
__global__ void BruteforceKnnShareKernel(const FloatType* data_points,
                                         const IdType* data_offsets,
                                         const FloatType* query_points,
                                         const IdType* query_offsets,
                                         const IdType* block_batch_id,
                                         const IdType* local_block_id,
                                         const int k, FloatType* dists,
                                         IdType* query_out, IdType* data_out,
                                         const int64_t num_batches,
                                         const int64_t feature_size) {
  const IdType block_idx = static_cast<IdType>(blockIdx.x);
  const IdType block_size = static_cast<IdType>(blockDim.x);
  const IdType batch_idx = block_batch_id[block_idx];
  const IdType local_bid = local_block_id[block_idx];
  const IdType query_start = query_offsets[batch_idx] + block_size * local_bid;
  const IdType query_end = min(query_start + block_size, query_offsets[batch_idx + 1]);
  if (query_start >= query_end) return;
  const IdType query_idx = query_start + threadIdx.x;
  const IdType data_start = data_offsets[batch_idx];
  const IdType data_end = data_offsets[batch_idx + 1];

  // shared memory: points in block + distance buffer + result buffer
  FloatType* data_buff = SharedMemory<FloatType>();
  FloatType* query_buff = data_buff + block_size * feature_size;
  FloatType* dist_buff = query_buff + block_size * feature_size;
  IdType* res_buff = reinterpret_cast<IdType*>(dist_buff + block_size * k);
  FloatType worst_dist = std::numeric_limits<FloatType>::max();

  // initialize dist buff with inf value
  for (auto i = 0; i < k; ++i) {
    dist_buff[threadIdx.x * k + i] = std::numeric_limits<FloatType>::max();
  }

  // load query data to shared memory
  if (query_idx < query_end) {
    for (auto i = 0; i < feature_size; ++i) {
      // to avoid bank conflict, we use transpose here
      query_buff[threadIdx.x + i * block_size] = query_points[query_idx * feature_size + i];
    }
  }

  // perform computation on each tile
  for (auto tile_start = data_start; tile_start < data_end; tile_start += block_size) {
    // each thread loads one data point into the shared memory
    IdType load_idx = tile_start + threadIdx.x;
    if (load_idx < data_end) {
      for (auto i = 0; i < feature_size; ++i) {
        data_buff[threadIdx.x * feature_size + i] = data_points[load_idx * feature_size + i];
      }
    }
    __syncthreads();

    // compute distance for one tile
    IdType true_block_size = min(data_end - tile_start, block_size);
    if (query_idx < query_end) {
      for (IdType d_idx = 0; d_idx < true_block_size; ++d_idx) {
        FloatType tmp_dist = 0;
        bool early_stop = false;
        IdType dim_idx = 0;

        for (; dim_idx < feature_size - 3; dim_idx += 4) {
          FloatType diff0 = query_buff[threadIdx.x + block_size * (dim_idx)]
            - data_buff[d_idx * feature_size + dim_idx];
          FloatType diff1 = query_buff[threadIdx.x + block_size * (dim_idx + 1)]
            - data_buff[d_idx * feature_size + dim_idx + 1];
          FloatType diff2 = query_buff[threadIdx.x + block_size * (dim_idx + 2)]
            - data_buff[d_idx * feature_size + dim_idx + 2];
          FloatType diff3 = query_buff[threadIdx.x + block_size * (dim_idx + 3)]
            - data_buff[d_idx * feature_size + dim_idx + 3];
          tmp_dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;

          if (tmp_dist > worst_dist) {
            early_stop = true;
            dim_idx = feature_size;
            break;
          }
        }

        for (; dim_idx < feature_size; ++dim_idx) {
          const FloatType diff = query_buff[threadIdx.x + dim_idx * block_size]
            - data_buff[d_idx * feature_size + dim_idx];
          tmp_dist += diff * diff;
          if (tmp_dist > worst_dist) {
            early_stop = true;
            break;
          }
        }

        if (early_stop) continue;

        HeapInsert(
          res_buff + threadIdx.x * k, dist_buff + threadIdx.x * k,
          d_idx + tile_start, tmp_dist, k);
        worst_dist = dist_buff[threadIdx.x * k];
      }
    }
    // make sure every thread is done with this tile before the next
    // iteration overwrites data_buff
    __syncthreads();
  }

  // copy result to global memory
  if (query_idx < query_end) {
    for (auto i = 0; i < k; ++i) {
      dists[query_idx * k + i] = dist_buff[threadIdx.x * k + i];
      data_out[query_idx * k + i] = res_buff[threadIdx.x * k + i];
      query_out[query_idx * k + i] = query_idx;
    }
  }
}

/*! \brief determine the number of blocks for each segment */
template <typename IdType>
__global__ void GetNumBlockPerSegment(const IdType* offsets, IdType* out,
                                      const int64_t batch_size, const int64_t block_size) {
  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < batch_size) {
    out[idx] = (offsets[idx + 1] - offsets[idx] - 1) / block_size + 1;
  }
}

/*! \brief Get the batch index and local index in segment for each block */
template <typename IdType>
__global__ void GetBlockInfo(const IdType* num_block_prefixsum,
                             IdType* block_batch_id, IdType* local_block_id,
                             size_t batch_size, size_t num_blocks) {
  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IdType i = 0;
  if (idx < num_blocks) {
    for (; i < batch_size; ++i) {
      if (num_block_prefixsum[i] > idx) break;
    }
    i--;
    block_batch_id[idx] = i;
    local_block_id[idx] = idx - num_block_prefixsum[i];
  }
}
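
// Worked example (added for clarity, not part of the upstream file): with
// block_size = 4 and query_offsets = {0, 6, 9}, GetNumBlockPerSegment yields
// {2, 1} blocks per segment, whose exclusive prefix sum is {0, 2}.
// GetBlockInfo then maps global block ids {0, 1, 2} to
// (batch_id, local_block_id) pairs (0, 0), (0, 1) and (1, 0).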
/*!
 * \brief Brute force kNN. Compute distance for each pair of input points and get
 *  the result directly (without a distance matrix).
 *
 * \tparam FloatType The type of input points.
 * \tparam IdType The type of id.
 * \param data_points NDArray of dataset points.
 * \param data_offsets offsets of point index in data points.
 * \param query_points NDArray of query points.
 * \param query_offsets offsets of point index in query points.
 * \param k the number of nearest points.
 * \param result output array.
 */
template <typename FloatType, typename IdType>
void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
                       const NDArray& query_points, const IdArray& query_offsets,
                       const int k, IdArray result) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = data_points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
    ctx, k * query_points->shape[0] * sizeof(FloatType)));

  const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);
  const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;
  CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, stream,
                   data_points_data, data_offsets_data, query_points_data,
                   query_offsets_data, k, dists, query_out, data_out,
                   batch_size, feature_size);

  device->FreeWorkspace(ctx, dists);
}

/*!
 * \brief Brute force kNN with shared memory.
 *  This function divides query points and data points into blocks. For each
 *  query block, it will make a loop over all data blocks and compute distances.
 *  It will be faster when the dimension of input points is not large.
 *
 * \tparam FloatType The type of input points.
 * \tparam IdType The type of id.
 * \param data_points NDArray of dataset points.
 * \param data_offsets offsets of point index in data points.
 * \param query_points NDArray of query points.
 * \param query_offsets offsets of point index in query points.
 * \param k the number of nearest points.
 * \param result output array.
 */
template <typename FloatType, typename IdType>
void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_offsets,
                             const NDArray& query_points, const IdArray& query_offsets,
                             const int k, IdArray result) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = data_points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  // get max shared memory per block in bytes
  // determine block size according to this value
  int max_sharedmem_per_block = 0;
  CUDA_CALL(hipDeviceGetAttribute(
    &max_sharedmem_per_block, hipDeviceAttributeMaxSharedMemoryPerBlock,
    ctx.device_id));
  const int64_t single_shared_mem = (k + 2 * feature_size) * sizeof(FloatType)
    + k * sizeof(IdType);
  const int64_t block_size = cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem);
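
  // Worked example (added for clarity, not part of the upstream file): per
  // thread the kernel keeps one data point and one query point (feature_size
  // floats each) plus k distances and k result ids, which is exactly
  // single_shared_mem above. For k = 8, feature_size = 16, float distances
  // and int32 ids this is (8 + 32) * 4 + 8 * 4 = 192 bytes, so a 48 KB shared
  // memory budget allows at most 49152 / 192 = 256 threads before
  // FindNumThreads rounds the block size.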
  // Determine the number of blocks. We first get the number of blocks for each
  // segment. Then we get the block id offset via prefix sum.
  IdType* num_block_per_segment = static_cast<IdType*>(
    device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));
  IdType* num_block_prefixsum = static_cast<IdType*>(
    device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));

  // block size for GetNumBlockPerSegment computation
  int64_t temp_block_size = cuda::FindNumThreads(batch_size);
  int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
  CUDA_KERNEL_CALL(GetNumBlockPerSegment, temp_num_blocks, temp_block_size, 0,
                   stream, query_offsets_data, num_block_per_segment,
                   batch_size, block_size);
  // hipcub two-phase idiom: the first call (null buffer) only queries the
  // temporary storage size, the second performs the actual scan
  size_t prefix_temp_size = 0;
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
    nullptr, prefix_temp_size, num_block_per_segment,
    num_block_prefixsum, batch_size, stream));
  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
    prefix_temp, prefix_temp_size, num_block_per_segment,
    num_block_prefixsum, batch_size, stream));
  device->FreeWorkspace(ctx, prefix_temp);

  // total number of blocks = last prefix-sum entry + last segment's count
  int64_t num_blocks = 0, final_elem = 0, copyoffset = (batch_size - 1) * sizeof(IdType);
  device->CopyDataFromTo(
    num_block_prefixsum, copyoffset, &num_blocks, 0, sizeof(IdType),
    ctx, DLContext{kDLCPU, 0}, query_offsets->dtype);
  device->CopyDataFromTo(
    num_block_per_segment, copyoffset, &final_elem, 0, sizeof(IdType),
    ctx, DLContext{kDLCPU, 0}, query_offsets->dtype);
  num_blocks += final_elem;

  // get batch id and local id in segment
  temp_block_size = cuda::FindNumThreads(num_blocks);
  temp_num_blocks = (num_blocks - 1) / temp_block_size + 1;
  IdType* block_batch_id = static_cast<IdType*>(device->AllocWorkspace(
    ctx, num_blocks * sizeof(IdType)));
  IdType* local_block_id = static_cast<IdType*>(device->AllocWorkspace(
    ctx, num_blocks * sizeof(IdType)));
  CUDA_KERNEL_CALL(
    GetBlockInfo, temp_num_blocks, temp_block_size, 0, stream,
    num_block_prefixsum, block_batch_id, local_block_id, batch_size, num_blocks);
  // free the prefix-sum buffers only after GetBlockInfo has consumed them
  device->FreeWorkspace(ctx, num_block_per_segment);
  device->FreeWorkspace(ctx, num_block_prefixsum);

  FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
    ctx, k * query_points->shape[0] * sizeof(FloatType)));
  CUDA_KERNEL_CALL(BruteforceKnnShareKernel, num_blocks, block_size,
                   single_shared_mem * block_size, stream,
                   data_points_data, data_offsets_data, query_points_data,
                   query_offsets_data, block_batch_id, local_block_id, k,
                   dists, query_out, data_out, batch_size, feature_size);

  device->FreeWorkspace(ctx, dists);
  device->FreeWorkspace(ctx, local_block_id);
  device->FreeWorkspace(ctx, block_batch_id);
}

/*! \brief Setup rng state for nn-descent */
__global__ void SetupRngKernel(hiprandState* states, const uint64_t seed, const size_t n) {
  size_t id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id < n) {
    hiprand_init(seed, id, 0, states + id);
  }
}
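
// Note (added for clarity, not part of the upstream file): SetupRngKernel
// prepares persistent per-thread RNG states, while the nn-descent kernels
// below instead initialize a local state inline with the same convention,
//   hiprand_init(seed, point_idx, /*offset=*/0, &state);
// so each thread draws from its own independent subsequence of one
// host-side seed, and a fresh seed is drawn for every randomized launch.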
/*!
 * \brief Randomly initialize neighbors (sampling without replacement)
 *  for each node.
 */
template <typename FloatType, typename IdType>
__global__ void RandomInitNeighborsKernel(const FloatType* points, const IdType* offsets,
                                          IdType* central_nodes, IdType* neighbors,
                                          FloatType* dists, bool* flags, const int k,
                                          const int64_t feature_size, const int64_t batch_size,
                                          const uint64_t seed) {
  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  IdType batch_idx = 0;
  if (point_idx >= offsets[batch_size]) return;
  hiprandState state;
  hiprand_init(seed, point_idx, 0, &state);

  // find the segment location in the input batch
  for (IdType b = 0; b < batch_size + 1; ++b) {
    if (offsets[b] > point_idx) {
      batch_idx = b - 1;
      break;
    }
  }

  const IdType segment_size = offsets[batch_idx + 1] - offsets[batch_idx];
  IdType* current_neighbors = neighbors + point_idx * k;
  IdType* current_central_nodes = central_nodes + point_idx * k;
  bool* current_flags = flags + point_idx * k;
  FloatType* current_dists = dists + point_idx * k;
  IdType segment_start = offsets[batch_idx];

  // reservoir sampling
  for (IdType i = 0; i < k; ++i) {
    current_neighbors[i] = i + segment_start;
    current_central_nodes[i] = point_idx;
  }
  for (IdType i = k; i < segment_size; ++i) {
    const IdType j = static_cast<IdType>(hiprand(&state) % (i + 1));
    if (j < k) current_neighbors[j] = i + segment_start;
  }

  // compute distances and set flags
  for (IdType i = 0; i < k; ++i) {
    current_flags[i] = true;
    current_dists[i] = EuclideanDist<FloatType, IdType>(
      points + point_idx * feature_size,
      points + current_neighbors[i] * feature_size,
      feature_size);
  }

  // build heap
  BuildHeap(neighbors + point_idx * k, current_dists, k);
}
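
// Worked example (added for clarity, not part of the upstream file): the loop
// above is classic reservoir sampling. With k = 2 and segment_size = 5,
// candidates 0 and 1 fill the reservoir; each later candidate i (i = 2, 3, 4)
// replaces slot j = rand() % (i + 1) only when j < k, i.e. with probability
// k / (i + 1), which leaves every candidate in the reservoir with equal
// probability k / segment_size.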
/*! \brief Randomly select candidates from current knn and reverse-knn graph for nn-descent */
template <typename IdType>
__global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidates,
                                     IdType* old_candidates, IdType* neighbors,
                                     bool* flags, const uint64_t seed,
                                     const int64_t batch_size, const int num_candidates,
                                     const int k) {
  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  IdType batch_idx = 0;
  if (point_idx >= offsets[batch_size]) return;
  hiprandState state;
  hiprand_init(seed, point_idx, 0, &state);

  // find the segment location in the input batch
  for (IdType b = 0; b < batch_size + 1; ++b) {
    if (offsets[b] > point_idx) {
      batch_idx = b - 1;
      break;
    }
  }

  IdType segment_start = offsets[batch_idx], segment_end = offsets[batch_idx + 1];
  IdType* current_neighbors = neighbors + point_idx * k;
  bool* current_flags = flags + point_idx * k;

  // reset candidates
  IdType* new_candidates_ptr = new_candidates + point_idx * (num_candidates + 1);
  IdType* old_candidates_ptr = old_candidates + point_idx * (num_candidates + 1);
  new_candidates_ptr[0] = 0;
  old_candidates_ptr[0] = 0;

  // select candidates from current knn graph
  // here we use candidate[0] for reservoir sampling temporarily
  for (IdType i = 0; i < k; ++i) {
    IdType candidate = current_neighbors[i];
    IdType* candidate_array = current_flags[i] ? new_candidates_ptr : old_candidates_ptr;
    IdType curr_num = candidate_array[0];
    IdType* candidate_data = candidate_array + 1;

    // reservoir sampling
    if (curr_num < num_candidates) {
      candidate_data[curr_num] = candidate;
    } else {
      IdType pos = static_cast<IdType>(hiprand(&state) % (curr_num + 1));
      if (pos < num_candidates) candidate_data[pos] = candidate;
    }
    ++candidate_array[0];
  }

  // select candidates from current reverse knn graph
  // here we use candidate[0] for reservoir sampling temporarily
  IdType index_start = segment_start * k, index_end = segment_end * k;
  for (IdType i = index_start; i < index_end; ++i) {
    if (neighbors[i] == point_idx) {
      IdType reverse_candidate = (i - index_start) / k + segment_start;
      IdType* candidate_array = flags[i] ? new_candidates_ptr : old_candidates_ptr;
      IdType curr_num = candidate_array[0];
      IdType* candidate_data = candidate_array + 1;

      // reservoir sampling
      if (curr_num < num_candidates) {
        candidate_data[curr_num] = reverse_candidate;
      } else {
        IdType pos = static_cast<IdType>(hiprand(&state) % (curr_num + 1));
        if (pos < num_candidates) candidate_data[pos] = reverse_candidate;
      }
      ++candidate_array[0];
    }
  }

  // set candidate[0] back to length
  if (new_candidates_ptr[0] > num_candidates) new_candidates_ptr[0] = num_candidates;
  if (old_candidates_ptr[0] > num_candidates) old_candidates_ptr[0] = num_candidates;

  // mark new_candidates as old
  IdType num_new_candidates = new_candidates_ptr[0];
  for (IdType i = 0; i < k; ++i) {
    IdType neighbor_idx = current_neighbors[i];
    if (current_flags[i]) {
      for (IdType j = 1; j < num_new_candidates + 1; ++j) {
        if (new_candidates_ptr[j] == neighbor_idx) {
          current_flags[i] = false;
          break;
        }
      }
    }
  }
}
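
// Layout note (added for clarity, not part of the upstream file): each point
// owns num_candidates + 1 slots per candidate array, e.g. with
// num_candidates = 3:
//   new_candidates_ptr -> [len | c0 | c1 | c2]
// Slot 0 doubles as the running element count during reservoir sampling and
// is clamped back to the stored length afterwards.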
/*! \brief Update knn graph according to selected candidates for nn-descent */
template <typename FloatType, typename IdType>
__global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* offsets,
                                      IdType* neighbors, IdType* new_candidates,
                                      IdType* old_candidates, FloatType* distances,
                                      bool* flags, IdType* num_updates,
                                      const int64_t batch_size, const int num_candidates,
                                      const int k, const int64_t feature_size) {
  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= offsets[batch_size]) return;
  IdType* current_neighbors = neighbors + point_idx * k;
  bool* current_flags = flags + point_idx * k;
  FloatType* current_dists = distances + point_idx * k;
  IdType* new_candidates_ptr = new_candidates + point_idx * (num_candidates + 1);
  IdType* old_candidates_ptr = old_candidates + point_idx * (num_candidates + 1);
  IdType num_new_candidates = new_candidates_ptr[0];
  IdType num_old_candidates = old_candidates_ptr[0];
  IdType current_num_updates = 0;

  // process new candidates
  for (IdType i = 1; i <= num_new_candidates; ++i) {
    IdType new_c = new_candidates_ptr[i];
    // new/old candidates of the current new candidate
    IdType* twohop_new_ptr = new_candidates + new_c * (num_candidates + 1);
    IdType* twohop_old_ptr = old_candidates + new_c * (num_candidates + 1);
    IdType num_twohop_new = twohop_new_ptr[0];
    IdType num_twohop_old = twohop_old_ptr[0];
    FloatType worst_dist = current_dists[0];

    // new - new
    for (IdType j = 1; j <= num_twohop_new; ++j) {
      IdType twohop_new_c = twohop_new_ptr[j];
      FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
        points + point_idx * feature_size,
        points + twohop_new_c * feature_size,
        feature_size, worst_dist);

      if (FlaggedHeapInsert(
            current_neighbors, current_dists, current_flags,
            twohop_new_c, new_dist, true, k, true)) {
        ++current_num_updates;
        worst_dist = current_dists[0];
      }
    }

    // new - old
    for (IdType j = 1; j <= num_twohop_old; ++j) {
      IdType twohop_old_c = twohop_old_ptr[j];
      FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
        points + point_idx * feature_size,
        points + twohop_old_c * feature_size,
        feature_size, worst_dist);

      if (FlaggedHeapInsert(
            current_neighbors, current_dists, current_flags,
            twohop_old_c, new_dist, true, k, true)) {
        ++current_num_updates;
        worst_dist = current_dists[0];
      }
    }
  }

  // process old candidates
  for (IdType i = 1; i <= num_old_candidates; ++i) {
    IdType old_c = old_candidates_ptr[i];
    // new candidates of the current old candidate
    IdType* twohop_new_ptr = new_candidates + old_c * (num_candidates + 1);
    IdType num_twohop_new = twohop_new_ptr[0];
    FloatType worst_dist = current_dists[0];

    // old - new
    for (IdType j = 1; j <= num_twohop_new; ++j) {
      IdType twohop_new_c = twohop_new_ptr[j];
      FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
        points + point_idx * feature_size,
        points + twohop_new_c * feature_size,
        feature_size, worst_dist);

      if (FlaggedHeapInsert(
            current_neighbors, current_dists, current_flags,
            twohop_new_c, new_dist, true, k, true)) {
        ++current_num_updates;
        worst_dist = current_dists[0];
      }
    }
  }
  num_updates[point_idx] = current_num_updates;
}

}  // namespace impl

template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
         const NDArray& query_points, const IdArray& query_offsets,
         const int k, IdArray result, const std::string& algorithm) {
  if (algorithm == std::string("bruteforce")) {
    impl::BruteForceKNNCuda<FloatType, IdType>(
      data_points, data_offsets, query_points, query_offsets, k, result);
  } else if (algorithm == std::string("bruteforce-sharemem")) {
    impl::BruteForceKNNSharedCuda<FloatType, IdType>(
      data_points, data_offsets, query_points, query_offsets, k, result);
  } else {
    LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CUDA.";
  }
}
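
// Note (added for clarity, not part of the upstream file): NNDescent below
// follows the NN-descent scheme of Dong et al. (WWW 2011): after random
// initialization, each iteration samples candidates from the knn and
// reverse-knn graphs, tries them as two-hop neighbors, and counts heap
// updates. With delta = 0.001, k = 10 and 1,000,000 points, iteration stops
// early once a round makes at most 0.001 * 10 * 1000000 = 10000 updates.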
template <DLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets, IdArray result,
               const int k, const int num_iters, const int num_candidates,
               const double delta) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t num_nodes = points->shape[0];
  const int64_t feature_size = points->shape[1];
  const int64_t batch_size = offsets->shape[0] - 1;
  const IdType* offsets_data = offsets.Ptr<IdType>();
  const FloatType* points_data = points.Ptr<FloatType>();
  IdType* central_nodes = result.Ptr<IdType>();
  IdType* neighbors = central_nodes + k * num_nodes;
  uint64_t seed;

  int warp_size = 0;
  CUDA_CALL(hipDeviceGetAttribute(
    &warp_size, hipDeviceAttributeWarpSize, ctx.device_id));
  // We don't need large block sizes, since there's not much inter-thread communication
  int64_t block_size = warp_size;
  int64_t num_blocks = (num_nodes - 1) / block_size + 1;

  // allocate space for candidates, distances and flags
  // we use the first element in candidate array to represent length
  IdType* new_candidates = static_cast<IdType*>(
    device->AllocWorkspace(ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));
  IdType* old_candidates = static_cast<IdType*>(
    device->AllocWorkspace(ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));
  IdType* num_updates = static_cast<IdType*>(
    device->AllocWorkspace(ctx, num_nodes * sizeof(IdType)));
  FloatType* distances = static_cast<FloatType*>(
    device->AllocWorkspace(ctx, num_nodes * k * sizeof(FloatType)));
  bool* flags = static_cast<bool*>(
    device->AllocWorkspace(ctx, num_nodes * k * sizeof(bool)));

  size_t sum_temp_size = 0;
  IdType total_num_updates = 0;
  IdType* total_num_updates_d = static_cast<IdType*>(
    device->AllocWorkspace(ctx, sizeof(IdType)));
  CUDA_CALL(hipcub::DeviceReduce::Sum(
    nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes, stream));
  IdType* sum_temp_storage = static_cast<IdType*>(
    device->AllocWorkspace(ctx, sum_temp_size));

  // random initialize neighbors
  seed = RandomEngine::ThreadLocal()->RandInt(
    std::numeric_limits<int64_t>::max());
  CUDA_KERNEL_CALL(
    impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream,
    points_data, offsets_data, central_nodes, neighbors, distances, flags,
    k, feature_size, batch_size, seed);

  for (int i = 0; i < num_iters; ++i) {
    // select candidates
    seed = RandomEngine::ThreadLocal()->RandInt(
      std::numeric_limits<int64_t>::max());
    CUDA_KERNEL_CALL(
      impl::FindCandidatesKernel, num_blocks, block_size, 0, stream,
      offsets_data, new_candidates, old_candidates, neighbors, flags, seed,
      batch_size, num_candidates, k);

    // update
    CUDA_KERNEL_CALL(
      impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream,
      points_data, offsets_data, neighbors, new_candidates, old_candidates,
      distances, flags, num_updates, batch_size, num_candidates, k, feature_size);

    total_num_updates = 0;
    CUDA_CALL(hipcub::DeviceReduce::Sum(
      sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d,
      num_nodes, stream));
    device->CopyDataFromTo(
      total_num_updates_d, 0, &total_num_updates, 0, sizeof(IdType),
      ctx, DLContext{kDLCPU, 0}, offsets->dtype);

    // stop when the graph is approximately stable
    if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) {
      break;
    }
  }

  device->FreeWorkspace(ctx, new_candidates);
  device->FreeWorkspace(ctx, old_candidates);
  device->FreeWorkspace(ctx, num_updates);
  device->FreeWorkspace(ctx, distances);
  device->FreeWorkspace(ctx, flags);
  device->FreeWorkspace(ctx, total_num_updates_d);
  device->FreeWorkspace(ctx, sum_temp_storage);
}

template void KNN<kDLGPU, float, int32_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, float, int64_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, double, int32_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, double, int64_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);

template void NNDescent<kDLGPU, float, int32_t>(
  const NDArray& points, const IdArray& offsets, IdArray result,
  const int k, const int num_iters, const int num_candidates,
  const double delta);
template void NNDescent<kDLGPU, float, int64_t>(
  const NDArray& points, const IdArray& offsets, IdArray result,
  const int k, const int num_iters, const int num_candidates,
  const double delta);
template void NNDescent<kDLGPU, double, int32_t>(
  const NDArray& points, const IdArray& offsets, IdArray result,
  const int k, const int num_iters, const int num_candidates,
  const double delta);
template void NNDescent<kDLGPU, double, int64_t>(
  const NDArray& points, const IdArray& offsets, IdArray result,
  const int k, const int num_iters, const int num_candidates,
  const double delta);

}  // namespace transform
}  // namespace dgl