#include "hip/hip_runtime.h" /*! * Copyright (c) 2021 by Contributors * \file array/cuda/negative_sampling.cu * \brief rowwise sampling */ #include #include #include #include #include "./dgl_cub.cuh" #include "./utils.h" #include "../../runtime/cuda/cuda_common.h" using namespace dgl::runtime; namespace dgl { namespace aten { namespace impl { namespace { template __global__ void _GlobalUniformNegativeSamplingKernel( const IdType* __restrict__ indptr, const IdType* __restrict__ indices, IdType* __restrict__ row, IdType* __restrict__ col, int64_t num_row, int64_t num_col, int64_t num_samples, int num_trials, bool exclude_self_loops, int32_t random_seed) { int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; const int stride_x = gridDim.x * blockDim.x; hiprandStatePhilox4_32_10_t rng; // this allows generating 4 32-bit ints at a time hiprand_init(random_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng); while (tx < num_samples) { for (int i = 0; i < num_trials; ++i) { uint4 result = hiprand4(&rng); // Turns out that result.x is always 0 with the above RNG. uint64_t y_hi = result.y >> 16; uint64_t y_lo = result.y & 0xFFFF; uint64_t z = static_cast(result.z); uint64_t w = static_cast(result.w); int64_t u = static_cast(((y_lo << 32L) | z) % num_row); int64_t v = static_cast(((y_hi << 32L) | w) % num_col); if (exclude_self_loops && (u == v)) continue; // binary search of v among indptr[u:u+1] int64_t b = indptr[u], e = indptr[u + 1] - 1; bool found = false; while (b <= e) { int64_t m = (b + e) / 2; if (indices[m] == v) { found = true; break; } else if (indices[m] < v) { b = m + 1; } else { e = m - 1; } } if (!found) { row[tx] = u; col[tx] = v; break; } } tx += stride_x; } } template struct IsNotMinusOne { __device__ __forceinline__ bool operator() (const std::pair& a) { return a.first != -1; } }; /*! * \brief Sort ordered pairs in ascending order, using \a tmp_major and \a tmp_minor as * temporary buffers, each with \a n elements. */ template void SortOrderedPairs( runtime::DeviceAPI* device, DLContext ctx, IdType* major, IdType* minor, IdType* tmp_major, IdType* tmp_minor, int64_t n, hipStream_t stream) { // Sort ordered pairs in lexicographical order by two radix sorts since // cub's radix sorts are stable. // We need a 2*n auxiliary storage to store the results form the first radix sort. size_t s1 = 0, s2 = 0; void* tmp1 = nullptr; void* tmp2 = nullptr; // Radix sort by minor key first, reorder the major key in the progress. CUDA_CALL(hipcub::DeviceRadixSort::SortPairs( tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8, stream)); tmp1 = device->AllocWorkspace(ctx, s1); CUDA_CALL(hipcub::DeviceRadixSort::SortPairs( tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8, stream)); // Radix sort by major key next. CUDA_CALL(hipcub::DeviceRadixSort::SortPairs( tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8, stream)); tmp2 = (s2 > s1) ? 
};  // namespace

template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  auto ctx = csr.indptr->ctx;
  auto dtype = csr.indptr->dtype;
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
  // Oversample by the given redundancy factor so that enough samples survive
  // the filtering steps below.
  const int64_t num_actual_samples =
      static_cast<int64_t>(num_samples * (1 + redundancy));
  IdArray row = Full<IdType>(-1, num_actual_samples, ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, ctx);
  IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx);
  IdArray out_col = IdArray::Empty({num_actual_samples}, dtype, ctx);
  IdType* row_data = row.Ptr<IdType>();
  IdType* col_data = col.Ptr<IdType>();
  IdType* out_row_data = out_row.Ptr<IdType>();
  IdType* out_col_data = out_col.Ptr<IdType>();
  auto device = runtime::DeviceAPI::Get(ctx);
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = cuda::FindNumThreads(num_actual_samples);
  const int nb = (num_actual_samples + nt - 1) / nt;
  std::pair<IdArray, IdArray> result;

  int64_t num_out;

  CUDA_KERNEL_CALL(
      _GlobalUniformNegativeSamplingKernel, nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), row_data, col_data,
      num_row, num_col, num_actual_samples, num_trials, exclude_self_loops,
      RandomEngine::ThreadLocal()->RandInt32());

  // Compact the sampled pairs, dropping the slots still holding -1, i.e.
  // slots whose trials all collided with existing edges.
  size_t tmp_size = 0;
  int64_t* num_out_cuda =
      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
  IsNotMinusOne<IdType> op;
  PairIterator<IdType> begin(row_data, col_data);
  PairIterator<IdType> out_begin(out_row_data, out_col_data);
  CUDA_CALL(hipcub::DeviceSelect::If(
      nullptr, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples,
      op, stream));
  void* tmp = device->AllocWorkspace(ctx, tmp_size);
  CUDA_CALL(hipcub::DeviceSelect::If(
      tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,
      stream));
  num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);
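
  // For example (illustrative only): if the kernel produced
  //   row = [ 4, -1,  7, -1]    col = [ 2, -1,  0, -1]
  // the selection above compacts the zipped pairs to
  //   out_row = [4, 7]          out_col = [2, 0]
  // and writes num_out = 2 to num_out_cuda on the device.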
  if (!replace) {
    // Deduplicate: sort the surviving pairs lexicographically, then keep
    // only the unique ones.
    IdArray unique_row = IdArray::Empty({num_out}, dtype, ctx);
    IdArray unique_col = IdArray::Empty({num_out}, dtype, ctx);
    IdType* unique_row_data = unique_row.Ptr<IdType>();
    IdType* unique_col_data = unique_col.Ptr<IdType>();
    PairIterator<IdType> unique_begin(unique_row_data, unique_col_data);

    SortOrderedPairs(
        device, ctx, out_row_data, out_col_data, unique_row_data,
        unique_col_data, num_out, stream);

    size_t tmp_size_unique = 0;
    void* tmp_unique = nullptr;
    CUDA_CALL(hipcub::DeviceSelect::Unique(
        nullptr, tmp_size_unique, out_begin, unique_begin, num_out_cuda,
        num_out, stream));
    tmp_unique = (tmp_size_unique > tmp_size)
                     ? device->AllocWorkspace(ctx, tmp_size_unique)
                     : tmp;  // reuse the selection buffer if it is big enough
    CUDA_CALL(hipcub::DeviceSelect::Unique(
        tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda,
        num_out, stream));
    num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);
    num_out = std::min(num_samples, num_out);
    result = {
        unique_row.CreateView({num_out}, dtype),
        unique_col.CreateView({num_out}, dtype)};

    if (tmp_unique != tmp) device->FreeWorkspace(ctx, tmp_unique);
  } else {
    num_out = std::min(num_samples, num_out);
    result = {
        out_row.CreateView({num_out}, dtype),
        out_col.CreateView({num_out}, dtype)};
  }

  device->FreeWorkspace(ctx, tmp);
  device->FreeWorkspace(ctx, num_out_cuda);
  return result;
}

template std::pair<IdArray, IdArray>
CSRGlobalUniformNegativeSampling<kDLGPU, int32_t>(
    const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray>
CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>(
    const CSRMatrix&, int64_t, int, bool, bool, double);

};  // namespace impl
};  // namespace aten
};  // namespace dgl
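
// Usage sketch (hypothetical caller; everything except
// CSRGlobalUniformNegativeSampling itself is assumed for illustration):
// draw up to 1000 unique negative edges from an existing CSRMatrix `csr`,
// oversampling by 25% and retrying each slot at most 3 times:
//
//   std::pair<IdArray, IdArray> neg =
//       CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>(
//           csr, /*num_samples=*/1000, /*num_trials=*/3,
//           /*exclude_self_loops=*/true, /*replace=*/false,
//           /*redundancy=*/0.25);
//   IdArray neg_row = neg.first;   // source endpoints
//   IdArray neg_col = neg.second;  // destination endpoints
//
// Because rejected slots and duplicates are filtered out, the returned
// arrays may hold fewer than num_samples pairs.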