Unverified Commit 86c81b4e authored by Xin Yao, committed by GitHub

[Feature] Add CUDA Weighted Neighborhood Sampling (#4064)



* add weighted sampling without replacement (A-Chao)

* improve Algorithm A-Chao with block-wise prefix sum

* correctly fill out_idxs

* implement weighted sampling with replacement

* small fix

* merge host-side code of weighted/uniform sampling

* enable unit tests for cuda weighted sampling

* move thrust/cub wrapper to the cmake file

* update docs accordingly

* fix linting

* fix linting

* fix unit test

* Bump external CUB/Thrust versions

* Fix code style and update description of algorithm design

* [Feature] GPU support for weighted graph neighbor sampling (commit by pengqirong, OPPO)

* merge pengqirong's implementation

* revert the change to cub and thrust

* fix linting

* use DeviceSegmentedSort for better performance

* add more comments

* add necessary notes

* add necessary notes

* resolve some comments

* define THRUST_CUB_WRAPPED_NAMESPACE

* fix doc
Co-authored-by: 彭齐荣 (Peng Qirong) <657017034@qq.com>
parent 17f1432a
@@ -12,8 +12,7 @@
 	url = https://github.com/KarypisLab/METIS.git
 [submodule "third_party/cub"]
 	path = third_party/cub
-	url = https://github.com/NVlabs/cub.git
-	branch = 1.8.0
+	url = https://github.com/NVIDIA/cub.git
 [submodule "third_party/phmap"]
 	path = third_party/phmap
 	url = https://github.com/greg7mdp/parallel-hashmap.git
......
@@ -46,18 +46,12 @@ endif(NOT MSVC)
 if(USE_CUDA)
   message(STATUS "Build with CUDA support")
   project(dgl C CXX)
+  # see https://github.com/NVIDIA/thrust/issues/1401
+  add_definitions(-DTHRUST_CUB_WRAPPED_NAMESPACE=dgl)
   include(cmake/modules/CUDA.cmake)
-  if ((CUDA_VERSION_MAJOR LESS 11) OR
-      ((CUDA_VERSION_MAJOR EQUAL 11) AND (CUDA_VERSION_MINOR EQUAL 0)))
-    # For cuda<11, use external CUB/Thrust library because CUB is not part of CUDA.
-    # For cuda==11.0, use external CUB/Thrust library because there is a bug in the
-    # official CUB library which causes invalid device ordinal error for DGL. The bug
-    # is fixed by https://github.com/NVIDIA/cub/commit/9143e47e048641aa0e6ddfd645bcd54ff1059939
-    # in 11.1.
-    message(STATUS "Detected CUDA of version ${CUDA_VERSION}. Use external CUB/Thrust library.")
-    cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/thrust")
-    cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/cub")
-  endif()
+  message(STATUS "Use external CUB/Thrust library for a consistent API and performance.")
+  cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/thrust")
+  cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/cub")
 endif(USE_CUDA)
 # initial variables
......
@@ -60,7 +60,7 @@ Using CUDA UVA-based neighborhood sampling in DGL data loaders
 For the case where the graph is too large to fit onto the GPU memory, we introduce the
 CUDA UVA (Unified Virtual Addressing)-based sampling, in which GPUs perform the sampling
-on the graph pinned on CPU memory via zero-copy access.
+on the graph pinned in CPU memory via zero-copy access.
 You can enable UVA-based neighborhood sampling in DGL data loaders via:
 * Put the ``train_nid`` onto GPU.
@@ -138,9 +138,6 @@ You can build your own GPU sampling pipelines with the following functions that
 operating on GPU:
 * :func:`dgl.sampling.sample_neighbors`
-  * Only has support for uniform sampling; non-uniform sampling can only run on CPU.
 * :func:`dgl.sampling.random_walk`
 Subgraph extraction ops:
......
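A minimal sketch of how the UVA-based path described in the guide above could be driven from Python, assuming the PyTorch backend; the `pin_memory_()` call and the `use_uva` flag of `dgl.dataloading.DataLoader` are assumptions based on the referenced guide and are not part of this diff:

import dgl
import dgl.dataloading as dataloading
import torch

# Hypothetical sketch: the graph stays in pinned CPU memory, seed nodes live on
# the GPU, and sampling runs on the GPU via zero-copy (UVA) access.
g = dgl.rand_graph(1_000_000, 10_000_000)        # large graph kept on the CPU
g.pin_memory_()                                   # assumed API for pinning the graph
train_nid = torch.arange(10_000, device='cuda')   # seed nodes must be on the GPU

sampler = dataloading.MultiLayerNeighborSampler([10, 10])  # two-layer fanout
loader = dataloading.DataLoader(
    g, train_nid, sampler,
    device='cuda',        # produce blocks on the GPU
    use_uva=True,         # assumed flag enabling UVA-based sampling
    batch_size=1024, shuffle=True)

for input_nodes, output_nodes, blocks in loader:
    pass  # feed `blocks` to a model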
@@ -54,8 +54,6 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
         The features must be non-negative floats, and the sum of the features of
         inbound/outbound edges for every node must be positive (though they don't have
         to sum up to one). Otherwise, the result will be undefined.
-        If :attr:`prob` is not None, GPU sampling is not supported.
     replace : bool, optional
         If True, sample with replacement.
     copy_ndata: bool, optional
@@ -163,6 +161,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
     Node/edge features are not preserved. The original IDs of
     the sampled edges are stored as the `dgl.EID` feature in the returned graph.
+    GPU sampling is supported for this function. Refer to :ref:`guide-minibatch-gpu-sampling`
+    for more details.
     Parameters
     ----------
     g : DGLGraph
@@ -193,8 +194,6 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
         The features must be non-negative floats, and the sum of the features of
         inbound/outbound edges for every node must be positive (though they don't have
         to sum up to one). Otherwise, the result will be undefined.
-        If :attr:`prob` is not None, GPU sampling is not supported.
     exclude_edges: tensor or dict
         Edge IDs to exclude during sampling neighbors for the seed nodes.
......
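To make the docstring change above concrete, here is a hedged sketch of weighted neighbor sampling that now also runs on the GPU; it assumes the PyTorch backend and uses a hypothetical edge feature name 'p':

import dgl
import torch

# Hypothetical example: a small graph with an unnormalized, non-negative edge
# weight stored as an edge feature, moved to the GPU before sampling.
src = torch.tensor([0, 0, 0, 1, 1, 2])
dst = torch.tensor([1, 2, 3, 2, 3, 3])
g = dgl.graph((src, dst)).to('cuda')
g.edata['p'] = torch.tensor([1.0, 2.0, 0.5, 0.0, 3.0, 1.0], device='cuda')

seeds = torch.tensor([2, 3], device='cuda')
# Sample up to 2 in-edges per seed node, biased by the 'p' edge feature.
sg = dgl.sampling.sample_neighbors(g, seeds, 2, prob='p', replace=False)

# The original IDs of the picked edges are stored under dgl.EID.
print(sg.edges(), sg.edata[dgl.EID])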
@@ -549,11 +549,12 @@ COOMatrix CSRRowWiseSampling(
     CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
   COOMatrix ret;
   if (IsNullArray(prob)) {
-    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
+    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
       ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
     });
   } else {
-    ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSampling", {
+    CHECK_SAME_CONTEXT(rows, prob);
+    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
       ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
         ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
             mat, rows, num_samples, prob, replace);
......
@@ -7,13 +7,11 @@
 #ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
 #define DGL_ARRAY_CUDA_DGL_CUB_CUH_
-// include cub in a safe manner
-#define CUB_NS_PREFIX namespace dgl {
-#define CUB_NS_POSTFIX }
-#define CUB_NS_QUALIFIER ::dgl::cub
+// This should be defined in CMakeLists.txt
+#ifndef THRUST_CUB_WRAPPED_NAMESPACE
+static_assert(false, "THRUST_CUB_WRAPPED_NAMESPACE must be defined for DGL.");
+#endif
 #include "cub/cub.cuh"
-#undef CUB_NS_QUALIFIER
-#undef CUB_NS_POSTFIX
-#undef CUB_NS_PREFIX
 #endif
 /*!
  * Copyright (c) 2021 by Contributors
  * \file array/cuda/rowwise_sampling.cu
- * \brief rowwise sampling
+ * \brief uniform rowwise sampling
  */
 #include <dgl/random.h>
@@ -13,6 +13,7 @@
 #include "../../array/cuda/atomic.cuh"
 #include "../../runtime/cuda/cuda_common.h"
+using namespace dgl::aten::cuda;
 namespace dgl {
@@ -21,7 +22,7 @@ namespace impl {
 namespace {
-constexpr int CTA_SIZE = 128;
+constexpr int BLOCK_SIZE = 128;
 /**
  * @brief Compute the size of each row in the sampled CSR, without replacement.
@@ -41,14 +42,14 @@ __global__ void _CSRRowWiseSampleDegreeKernel(
     const IdType * const in_rows,
     const IdType * const in_ptr,
     IdType * const out_deg) {
-  const int tIdx = threadIdx.x + blockIdx.x*blockDim.x;
+  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
   if (tIdx < num_rows) {
     const int in_row = in_rows[tIdx];
    const int out_row = tIdx;
-    out_deg[out_row] = min(static_cast<IdType>(num_picks), in_ptr[in_row+1]-in_ptr[in_row]);
-    if (out_row == num_rows-1) {
+    out_deg[out_row] = min(static_cast<IdType>(num_picks), in_ptr[in_row + 1] - in_ptr[in_row]);
+    if (out_row == num_rows - 1) {
      // make the prefixsum work
      out_deg[num_rows] = 0;
    }
@@ -73,19 +74,19 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
     const IdType * const in_rows,
     const IdType * const in_ptr,
     IdType * const out_deg) {
-  const int tIdx = threadIdx.x + blockIdx.x*blockDim.x;
+  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
   if (tIdx < num_rows) {
     const int64_t in_row = in_rows[tIdx];
     const int64_t out_row = tIdx;
-    if (in_ptr[in_row+1]-in_ptr[in_row] == 0) {
+    if (in_ptr[in_row + 1] - in_ptr[in_row] == 0) {
       out_deg[out_row] = 0;
     } else {
       out_deg[out_row] = static_cast<IdType>(num_picks);
     }
-    if (out_row == num_rows-1) {
+    if (out_row == num_rows - 1) {
       // make the prefixsum work
       out_deg[num_rows] = 0;
     }
@@ -93,11 +94,10 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
 }
 /**
- * @brief Perform row-wise sampling on a CSR matrix, and generate a COO matrix,
- * without replacement.
+ * @brief Perform row-wise uniform sampling on a CSR matrix,
+ * and generate a COO matrix, without replacement.
  *
  * @tparam IdType The ID type used for matrices.
- * @tparam BLOCK_CTAS The number of rows each thread block runs in parallel.
  * @tparam TILE_SIZE The number of rows covered by each threadblock.
  * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
@@ -111,8 +111,8 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
  * @param out_cols The columns of the output COO (output).
  * @param out_idxs The data array of the output COO (output).
  */
-template<typename IdType, int BLOCK_CTAS, int TILE_SIZE>
-__global__ void _CSRRowWiseSampleKernel(
+template<typename IdType, int TILE_SIZE>
+__global__ void _CSRRowWiseSampleUniformKernel(
     const uint64_t rand_seed,
     const int64_t num_picks,
     const int64_t num_rows,
@@ -125,68 +125,62 @@ __global__ void _CSRRowWiseSampleKernel(
     IdType * const out_cols,
     IdType * const out_idxs) {
   // we assign one warp per row
-  assert(blockDim.x == CTA_SIZE);
-  int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
-  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
+  assert(blockDim.x == BLOCK_SIZE);
+  int64_t out_row = blockIdx.x * TILE_SIZE;
+  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
   curandStatePhilox4_32_10_t rng;
-  curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
+  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
   while (out_row < last_row) {
     const int64_t row = in_rows[out_row];
     const int64_t in_row_start = in_ptr[row];
-    const int64_t deg = in_ptr[row+1] - in_row_start;
+    const int64_t deg = in_ptr[row + 1] - in_row_start;
     const int64_t out_row_start = out_ptr[out_row];
     if (deg <= num_picks) {
-      // just copy row
-      for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
-        const IdType in_idx = in_row_start+idx;
-        out_rows[out_row_start+idx] = row;
-        out_cols[out_row_start+idx] = in_index[in_idx];
-        out_idxs[out_row_start+idx] = data ? data[in_idx] : in_idx;
+      // just copy row when there is not enough nodes to sample.
+      for (int idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
+        const IdType in_idx = in_row_start + idx;
+        out_rows[out_row_start + idx] = row;
+        out_cols[out_row_start + idx] = in_index[in_idx];
+        out_idxs[out_row_start + idx] = data ? data[in_idx] : in_idx;
       }
     } else {
       // generate permutation list via reservoir algorithm
-      for (int idx = threadIdx.x; idx < num_picks; idx+=CTA_SIZE) {
-        out_idxs[out_row_start+idx] = idx;
+      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
+        out_idxs[out_row_start + idx] = idx;
       }
       __syncthreads();
-      for (int idx = num_picks+threadIdx.x; idx < deg; idx+=CTA_SIZE) {
-        const int num = curand(&rng)%(idx+1);
+      for (int idx = num_picks + threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
+        const int num = curand(&rng) % (idx + 1);
         if (num < num_picks) {
           // use max so as to achieve the replacement order the serial
           // algorithm would have
-          AtomicMax(out_idxs+out_row_start+num, idx);
+          AtomicMax(out_idxs + out_row_start + num, idx);
         }
       }
       __syncthreads();
       // copy permutation over
-      for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
-        const IdType perm_idx = out_idxs[out_row_start+idx]+in_row_start;
-        out_rows[out_row_start+idx] = row;
-        out_cols[out_row_start+idx] = in_index[perm_idx];
-        if (data) {
-          out_idxs[out_row_start+idx] = data[perm_idx];
-        }
+      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
+        const IdType perm_idx = out_idxs[out_row_start + idx] + in_row_start;
+        out_rows[out_row_start + idx] = row;
+        out_cols[out_row_start + idx] = in_index[perm_idx];
+        out_idxs[out_row_start + idx] = data ? data[perm_idx] : perm_idx;
      }
    }
-    out_row += BLOCK_CTAS;
+    out_row += 1;
   }
 }
 /**
- * @brief Perform row-wise sampling on a CSR matrix, and generate a COO matrix,
- * with replacement.
+ * @brief Perform row-wise uniform sampling on a CSR matrix,
+ * and generate a COO matrix, with replacement.
  *
  * @tparam IdType The ID type used for matrices.
- * @tparam BLOCK_CTAS The number of rows each thread block runs in parallel.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
@@ -200,8 +194,8 @@ __global__ void _CSRRowWiseSampleKernel(
  * @param out_cols The columns of the output COO (output).
  * @param out_idxs The data array of the output COO (output).
  */
-template<typename IdType, int BLOCK_CTAS, int TILE_SIZE>
-__global__ void _CSRRowWiseSampleReplaceKernel(
+template<typename IdType, int TILE_SIZE>
+__global__ void _CSRRowWiseSampleUniformReplaceKernel(
     const uint64_t rand_seed,
     const int64_t num_picks,
     const int64_t num_rows,
@@ -214,39 +208,37 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
     IdType * const out_cols,
     IdType * const out_idxs) {
   // we assign one warp per row
-  assert(blockDim.x == CTA_SIZE);
-  int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
-  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
+  assert(blockDim.x == BLOCK_SIZE);
+  int64_t out_row = blockIdx.x * TILE_SIZE;
+  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
   curandStatePhilox4_32_10_t rng;
-  curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
+  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
   while (out_row < last_row) {
     const int64_t row = in_rows[out_row];
     const int64_t in_row_start = in_ptr[row];
     const int64_t out_row_start = out_ptr[out_row];
-    const int64_t deg = in_ptr[row+1] - in_row_start;
+    const int64_t deg = in_ptr[row + 1] - in_row_start;
     if (deg > 0) {
       // each thread then blindly copies in rows only if deg > 0.
-      for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
+      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
         const int64_t edge = curand(&rng) % deg;
-        const int64_t out_idx = out_row_start+idx;
+        const int64_t out_idx = out_row_start + idx;
         out_rows[out_idx] = row;
-        out_cols[out_idx] = in_index[in_row_start+edge];
-        out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
+        out_cols[out_idx] = in_index[in_row_start + edge];
+        out_idxs[out_idx] = data ? data[in_row_start + edge] : in_row_start + edge;
       }
     }
-    out_row += BLOCK_CTAS;
+    out_row += 1;
   }
 }
 }  // namespace
-/////////////////////////////// CSR ///////////////////////////////
+///////////////////////////// CSR sampling //////////////////////////
 template <DLDeviceType XPU, typename IdType>
 COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
@@ -277,22 +269,26 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   // compute degree
   IdType * out_deg = static_cast<IdType*>(
-      device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
+      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
   if (replace) {
     const dim3 block(512);
-    const dim3 grid((num_rows+block.x-1)/block.x);
-    _CSRRowWiseSampleDegreeReplaceKernel<<<grid, block, 0, stream>>>(
+    const dim3 grid((num_rows + block.x - 1) / block.x);
+    CUDA_KERNEL_CALL(
+        _CSRRowWiseSampleDegreeReplaceKernel,
+        grid, block, 0, stream,
         num_picks, num_rows, slice_rows, in_ptr, out_deg);
   } else {
     const dim3 block(512);
-    const dim3 grid((num_rows+block.x-1)/block.x);
-    _CSRRowWiseSampleDegreeKernel<<<grid, block, 0, stream>>>(
+    const dim3 grid((num_rows + block.x - 1) / block.x);
+    CUDA_KERNEL_CALL(
+        _CSRRowWiseSampleDegreeKernel,
+        grid, block, 0, stream,
         num_picks, num_rows, slice_rows, in_ptr, out_deg);
   }
   // fill out_ptr
   IdType * out_ptr = static_cast<IdType*>(
-      device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
+      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
   size_t prefix_temp_size = 0;
   CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
       out_deg,
@@ -314,24 +310,25 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
   // a cudaevent
   IdType new_len;
-  device->CopyDataFromTo(out_ptr, num_rows*sizeof(new_len), &new_len, 0,
+  device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
         sizeof(new_len),
         ctx,
         DGLContext{kDLCPU, 0},
         mat.indptr->dtype,
         stream);
   CUDA_CALL(cudaEventRecord(copyEvent, stream));
   const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
   // select edges
-  if (replace) {
-    constexpr int BLOCK_CTAS = 128/CTA_SIZE;
-    // the number of rows each thread block will cover
-    constexpr int TILE_SIZE = BLOCK_CTAS;
-    const dim3 block(CTA_SIZE, BLOCK_CTAS);
-    const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
-    _CSRRowWiseSampleReplaceKernel<IdType, BLOCK_CTAS, TILE_SIZE><<<grid, block, 0, stream>>>(
+  // the number of rows each thread block will cover
+  constexpr int TILE_SIZE = 128 / BLOCK_SIZE;
+  if (replace) {  // with replacement
+    const dim3 block(BLOCK_SIZE);
+    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
+    CUDA_KERNEL_CALL(
+        (_CSRRowWiseSampleUniformReplaceKernel<IdType, TILE_SIZE>),
+        grid, block, 0, stream,
         random_seed,
         num_picks,
         num_rows,
@@ -343,13 +340,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
         out_rows,
         out_cols,
         out_idxs);
-  } else {
-    constexpr int BLOCK_CTAS = 128/CTA_SIZE;
-    // the number of rows each thread block will cover
-    constexpr int TILE_SIZE = BLOCK_CTAS;
-    const dim3 block(CTA_SIZE, BLOCK_CTAS);
-    const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
-    _CSRRowWiseSampleKernel<IdType, BLOCK_CTAS, TILE_SIZE><<<grid, block, 0, stream>>>(
+  } else {  // without replacement
+    const dim3 block(BLOCK_SIZE);
+    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
+    CUDA_KERNEL_CALL(
+        (_CSRRowWiseSampleUniformKernel<IdType, TILE_SIZE>),
+        grid, block, 0, stream,
         random_seed,
         num_picks,
         num_rows,
......
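For reference, the without-replacement kernel above builds its per-row permutation with the reservoir trick its comments mention ("use max so as to achieve the replacement order the serial algorithm would have"). A rough serial sketch of what _CSRRowWiseSampleUniformKernel emulates per row, in plain Python and for illustration only:

import random

def reservoir_pick(deg, num_picks):
    """Serial counterpart of the per-row loop in the uniform sampling kernel."""
    if deg <= num_picks:
        return list(range(deg))          # not enough neighbors: copy the whole row
    out = list(range(num_picks))         # start with the first num_picks edge offsets
    for idx in range(num_picks, deg):
        num = random.randint(0, idx)     # uniform over [0, idx], like curand(&rng) % (idx + 1)
        if num < num_picks:
            out[num] = idx               # the GPU version resolves racing writers with AtomicMax
    return out                           # offsets into the row's edge range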
/*!
* Copyright (c) 2022 by Contributors
* \file array/cuda/rowwise_sampling_prob.cu
* \brief weighted rowwise sampling. The degree computing kernels and
* host-side functions are partially borrowed from the uniform rowwise
* sampling code rowwise_sampling.cu.
* \author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <numeric>
#include "./dgl_cub.cuh"
#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
// require CUB 1.17 to use DeviceSegmentedSort
static_assert(CUB_VERSION >= 101700);
using namespace dgl::aten::cuda;
namespace dgl {
namespace aten {
namespace impl {
namespace {
constexpr int BLOCK_SIZE = 128;
/**
* @brief Compute the size of each row in the sampled CSR, without replacement.
* temp_deg is calculated for rows with deg > num_picks.
* For these rows, we will calculate their A-Res values and sort them to get top-num_picks.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by `in_rows` (output).
* @param temp_deg The size of each row in the input matrix, as indexed by `in_rows` (output).
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
IdType * const out_deg,
IdType * const temp_deg) {
const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;
if (tIdx < num_rows) {
const int64_t in_row = in_rows[tIdx];
const int64_t out_row = tIdx;
const IdType deg = in_ptr[in_row + 1] - in_ptr[in_row];
// temp_deg is used to generate ares_ptr
temp_deg[out_row] = deg > static_cast<IdType>(num_picks) ? deg : 0;
out_deg[out_row] = min(static_cast<IdType>(num_picks), deg);
if (out_row == num_rows - 1) {
// make the prefixsum work
out_deg[num_rows] = 0;
temp_deg[num_rows] = 0;
}
}
}
/**
* @brief Compute the size of each row in the sampled CSR, with replacement.
* We need the actual in degree of each row to store CDF values.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by `in_rows` (output).
* @param temp_deg The size of each row in the input matrix, as indexed by `in_rows` (output).
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
IdType * const out_deg,
IdType * const temp_deg) {
const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;
if (tIdx < num_rows) {
const int64_t in_row = in_rows[tIdx];
const int64_t out_row = tIdx;
const IdType deg = in_ptr[in_row + 1] - in_ptr[in_row];
temp_deg[out_row] = deg;
out_deg[out_row] = deg == 0 ? 0 : static_cast<IdType>(num_picks);
if (out_row == num_rows - 1) {
// make the prefixsum work
out_deg[num_rows] = 0;
temp_deg[num_rows] = 0;
}
}
}
/**
* @brief Equivalent to numpy expression: array[idx[off:off + len]]
*
* @tparam IdType The ID type used for indices.
* @tparam FloatType The float type used for array values.
* @param array The array to be selected.
* @param idx_data The index mapping array.
* @param index The index of value to be selected.
* @param offset The offset to start.
* @param out The selected value (output).
*/
template<typename IdType, typename FloatType>
__device__ void _DoubleSlice(
const FloatType * const array,
const IdType * const idx_data,
const IdType idx,
const IdType offset,
FloatType* const out) {
if (idx_data) {
*out = array[idx_data[offset + idx]];
} else {
*out = array[offset + idx];
}
}
/**
* @brief Compute A-Res value. A-Res value needs to be calculated only if deg
* is greater than num_picks in weighted rowwise sampling without replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param ares_ptr The offset to write each row to in the A-res array.
* @param ares_idxs The A-Res value corresponding index array, the index of input CSR (output).
* @param ares The A-Res value array (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRAResValueKernel(
const uint64_t rand_seed,
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const data,
const FloatType * const prob,
const IdType * const ares_ptr,
IdType * const ares_idxs,
FloatType * const ares) {
int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
const int64_t deg = in_ptr[row + 1] - in_row_start;
// A-Res value needs to be calculated only if deg is greater than num_picks
// in weighted rowwise sampling without replacement
if (deg > num_picks) {
const int64_t ares_row_start = ares_ptr[out_row];
for (int64_t idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
const int64_t in_idx = in_row_start + idx;
const int64_t ares_idx = ares_row_start + idx;
FloatType item_prob;
_DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &item_prob);
// compute A-Res value
ares[ares_idx] = static_cast<FloatType>(__powf(curand_uniform(&rng), 1.0f / item_prob));
ares_idxs[ares_idx] = static_cast<IdType>(in_idx);
}
}
out_row += 1;
}
}
/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix,
* without replacement. After sorting, we select top-num_picks items.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param ares_ptr The offset to write each row to in the ares array.
* @param sort_ares_idxs The sorted A-Res value corresponding index array, the index of input CSR.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const in_cols,
const IdType * const data,
const IdType * const out_ptr,
const IdType * const ares_ptr,
const IdType * const sort_ares_idxs,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs) {
// we assign one warp per row
assert(blockDim.x == BLOCK_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
const int64_t out_row_start = out_ptr[out_row];
const int64_t deg = in_ptr[row + 1] - in_row_start;
if (deg > num_picks) {
const int64_t ares_row_start = ares_ptr[out_row];
for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
// get in and out index; in_idx is the input-CSR index of one of the
// top num_picks A-Res values.
const int64_t out_idx = out_row_start + idx;
const int64_t ares_idx = ares_row_start + idx;
const int64_t in_idx = sort_ares_idxs[ares_idx];
// copy permutation over
out_rows[out_idx] = static_cast<IdType>(row);
out_cols[out_idx] = in_cols[in_idx];
out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);
}
} else {
for (int64_t idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
// get in and out index
const int64_t out_idx = out_row_start + idx;
const int64_t in_idx = in_row_start + idx;
// copy permutation over
out_rows[out_idx] = static_cast<IdType>(row);
out_cols[out_idx] = in_cols[in_idx];
out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);
}
}
out_row += 1;
}
}
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template<typename FloatType>
struct BlockPrefixCallbackOp {
// Running prefix
FloatType running_total;
// Constructor
__device__ BlockPrefixCallbackOp(FloatType running_total) : running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ FloatType operator()(FloatType block_aggregate) {
FloatType old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix,
* with replacement. We store the CDF (unnormalized) of all neighbors of a row
* in global memory and use binary search to find inverse indices as selected items.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param cdf_ptr The offset of each cdf segment.
* @param cdf The global buffer to store cdf segments.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleReplaceKernel(
const uint64_t rand_seed,
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const in_cols,
const IdType * const data,
const FloatType * const prob,
const IdType * const out_ptr,
const IdType * const cdf_ptr,
FloatType * const cdf,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs
) {
// we assign one warp per row
assert(blockDim.x == BLOCK_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
const int64_t out_row_start = out_ptr[out_row];
const int64_t cdf_row_start = cdf_ptr[out_row];
const int64_t deg = in_ptr[row + 1] - in_row_start;
const FloatType MIN_THREAD_DATA = static_cast<FloatType>(0.0f);
if (deg > 0) {
// Specialize BlockScan for a 1D block of BLOCK_SIZE threads
typedef cub::BlockScan<FloatType, BLOCK_SIZE> BlockScan;
// Allocate shared memory for BlockScan
__shared__ typename BlockScan::TempStorage temp_storage;
// Initialize running total
BlockPrefixCallbackOp<FloatType> prefix_op(MIN_THREAD_DATA);
int64_t max_iter = (1 + (deg - 1) / BLOCK_SIZE) * BLOCK_SIZE;
// Have the block iterate over segments of items
for (int64_t idx = threadIdx.x; idx < max_iter; idx += BLOCK_SIZE) {
// Load a segment of consecutive items that are blocked across threads
FloatType thread_data;
if (idx < deg)
_DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &thread_data);
else
thread_data = MIN_THREAD_DATA;
thread_data = max(thread_data, MIN_THREAD_DATA);
// Collectively compute the block-wide inclusive prefix sum
BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
__syncthreads();
// Store scanned items to cdf array
if (idx < deg) {
cdf[cdf_row_start + idx] = thread_data;
}
}
__syncthreads();
for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
// get random value
FloatType sum = cdf[cdf_row_start + deg - 1];
FloatType rand = static_cast<FloatType>(curand_uniform(&rng) * sum);
// get the offset of the first value within cdf array which is greater than random value.
int64_t item = cub::UpperBound<FloatType*, int64_t, FloatType>(
&cdf[cdf_row_start], deg, rand);
item = min(item, deg - 1);
// get in and out index
const int64_t in_idx = in_row_start + item;
const int64_t out_idx = out_row_start + idx;
// copy permutation over
out_rows[out_idx] = static_cast<IdType>(row);
out_cols[out_idx] = in_cols[in_idx];
out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);
}
}
out_row += 1;
}
}
} // namespace
/////////////////////////////// CSR ///////////////////////////////
/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix.
* Use the CDF sampling algorithm for sampling with replacement:
* 1) Calculate the CDF of all neighbors' probabilities.
* 2) For each of the num_picks draws, generate rand ~ U(0, 1) and use binary
* search to find its index in the CDF array as the chosen item.
* Use the A-Res sampling algorithm for sampling without replacement:
* 1) For rows with deg > num_picks, calculate A-Res values for all neighbors.
* 2) Sort the A-Res array and select the top num_picks as chosen items.
*
* @tparam XPU The device type used for matrices.
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @param mat The CSR matrix.
* @param rows The set of rows to pick.
* @param num_picks The number of non-zeros to pick per row.
* @param prob The probability array of the input CSR.
* @param replace Whether to sample with replacement.
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat,
IdArray rows,
int64_t num_picks,
FloatArray prob,
bool replace) {
const auto& ctx = rows->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
// TODO(dlasalle): Once the device api supports getting the stream from the
// context, that should be used instead of the default stream here.
cudaStream_t stream = 0;
const int64_t num_rows = rows->shape[0];
const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
const IdType * const in_ptr = static_cast<const IdType*>(mat.indptr->data);
const IdType * const in_cols = static_cast<const IdType*>(mat.indices->data);
IdType* const out_rows = static_cast<IdType*>(picked_row->data);
IdType* const out_cols = static_cast<IdType*>(picked_col->data);
IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
const IdType* const data = CSRHasData(mat) ?
static_cast<IdType*>(mat.data->data) : nullptr;
const FloatType* const prob_data = static_cast<const FloatType*>(prob->data);
// compute degree
// out_deg: the size of each row in the sampled matrix
// temp_deg: the size of each row we will manipulate in sampling
// 1) for w/o replacement: in degree if it's greater than num_picks else 0
// 2) for w/ replacement: in degree
IdType * out_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
IdType * temp_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
if (replace) {
const dim3 block(512);
const dim3 grid((num_rows + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_CSRRowWiseSampleDegreeReplaceKernel,
grid, block, 0, stream,
num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
} else {
const dim3 block(512);
const dim3 grid((num_rows + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_CSRRowWiseSampleDegreeKernel,
grid, block, 0, stream,
num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
}
// fill temp_ptr
IdType * temp_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1)*sizeof(IdType)));
size_t prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
temp_deg,
temp_ptr,
num_rows + 1,
stream));
void * prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
temp_deg,
temp_ptr,
num_rows + 1,
stream));
device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, temp_deg);
// TODO(Xin): The copy here is too small, and the overhead of creating
// cuda events cannot be ignored. Just use synchronized copy.
IdType temp_len;
device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
sizeof(temp_len),
ctx,
DGLContext{kDLCPU, 0},
mat.indptr->dtype,
stream);
device->StreamSync(ctx, stream);
// fill out_ptr
IdType * out_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
out_deg,
out_ptr,
num_rows+1,
stream));
prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
out_deg,
out_ptr,
num_rows+1,
stream));
device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, out_deg);
cudaEvent_t copyEvent;
CUDA_CALL(cudaEventCreate(&copyEvent));
// TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
// a cudaevent
IdType new_len;
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
sizeof(new_len),
ctx,
DGLContext{kDLCPU, 0},
mat.indptr->dtype,
stream);
CUDA_CALL(cudaEventRecord(copyEvent, stream));
// allocate workspace
// 1) for w/ replacement, it's a global buffer to store cdf segments (one segment for each row).
// 2) for w/o replacement, it's used to store a-res segments (one segment for
// each row with degree > num_picks)
FloatType * temp = static_cast<FloatType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
const uint64_t rand_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
// select edges
// the number of rows each thread block will cover
constexpr int TILE_SIZE = 128 / BLOCK_SIZE;
if (replace) { // with replacement.
const dim3 block(BLOCK_SIZE);
const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
CUDA_KERNEL_CALL(
(_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>),
grid, block, 0, stream,
rand_seed,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
prob_data,
out_ptr,
temp_ptr,
temp,
out_rows,
out_cols,
out_idxs);
device->FreeWorkspace(ctx, temp);
} else { // without replacement
IdType* temp_idxs = static_cast<IdType*>(
device->AllocWorkspace(ctx, (temp_len) * sizeof(IdType)));
// Compute A-Res value. A-Res value needs to be calculated only if deg
// is greater than num_picks in weighted rowwise sampling without replacement.
const dim3 block(BLOCK_SIZE);
const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
CUDA_KERNEL_CALL(
(_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>),
grid, block, 0, stream,
rand_seed,
num_picks,
num_rows,
slice_rows,
in_ptr,
data,
prob_data,
temp_ptr,
temp_idxs,
temp);
// sort A-Res value array.
FloatType* sort_temp = static_cast<FloatType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
IdType* sort_temp_idxs = static_cast<IdType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(IdType)));
cub::DoubleBuffer<FloatType> sort_keys(temp, sort_temp);
cub::DoubleBuffer<IdType> sort_values(temp_idxs, sort_temp_idxs);
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
d_temp_storage,
temp_storage_bytes,
sort_keys,
sort_values,
temp_len,
num_rows,
temp_ptr,
temp_ptr + 1));
d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes);
CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
d_temp_storage,
temp_storage_bytes,
sort_keys,
sort_values,
temp_len,
num_rows,
temp_ptr,
temp_ptr + 1));
device->FreeWorkspace(ctx, d_temp_storage);
device->FreeWorkspace(ctx, temp);
device->FreeWorkspace(ctx, temp_idxs);
device->FreeWorkspace(ctx, sort_temp);
device->FreeWorkspace(ctx, sort_temp_idxs);
// select the top num_picks items as results
CUDA_KERNEL_CALL(
(_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>),
grid, block, 0, stream,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
out_ptr,
temp_ptr,
sort_values.Current(),
out_rows,
out_cols,
out_idxs);
}
device->FreeWorkspace(ctx, temp_ptr);
device->FreeWorkspace(ctx, out_ptr);
// wait for copying `new_len` to finish
CUDA_CALL(cudaEventSynchronize(copyEvent));
CUDA_CALL(cudaEventDestroy(copyEvent));
picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
return COOMatrix(mat.num_rows, mat.num_cols, picked_row,
picked_col, picked_idx);
}
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl
} // namespace aten
} // namespace dgl
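The two code paths above follow the algorithm summary in the file header: per-row A-Res keys sorted segment by segment for sampling without replacement, and an inclusive prefix sum (CDF) plus an upper-bound binary search for sampling with replacement. A small NumPy sketch of both, for a single row with weight vector prob (illustration only, not a DGL API; strictly positive weights assumed):

import numpy as np

rng = np.random.default_rng(0)

def pick_without_replacement(prob, num_picks):
    """A-Res: key_i = u_i ** (1 / w_i); keep the num_picks largest keys."""
    deg = len(prob)
    if deg <= num_picks:
        return np.arange(deg)                    # short row: copy everything
    keys = rng.random(deg) ** (1.0 / prob)       # _CSRAResValueKernel computes these per edge
    return np.argsort(-keys)[:num_picks]         # DeviceSegmentedSort (descending) + top num_picks

def pick_with_replacement(prob, num_picks):
    """CDF inversion: unnormalized running sum, then an upper-bound binary search."""
    cdf = np.cumsum(prob)                        # the kernel builds this with a BlockScan
    r = rng.random(num_picks) * cdf[-1]
    idx = np.searchsorted(cdf, r, side='right')  # cub::UpperBound equivalent
    return np.minimum(idx, len(prob) - 1)        # clamp, mirroring min(item, deg - 1)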
@@ -625,12 +625,10 @@ def test_sample_neighbors_noprob():
     _test_sample_neighbors(False, None)
     #_test_sample_neighbors(True)
-@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors with probability is not implemented")
 def test_sample_neighbors_prob():
     _test_sample_neighbors(False, 'prob')
     #_test_sample_neighbors(True)
-@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
 def test_sample_neighbors_outedge():
     _test_sample_neighbors_outedge(False)
     #_test_sample_neighbors_outedge(True)
@@ -645,9 +643,8 @@ def test_sample_neighbors_topk_outedge():
     _test_sample_neighbors_topk_outedge(False)
     #_test_sample_neighbors_topk_outedge(True)
-@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
 def test_sample_neighbors_with_0deg():
-    g = dgl.graph(([], []), num_nodes=5)
+    g = dgl.graph(([], []), num_nodes=5).to(F.ctx())
     sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=False)
     assert sg.number_of_edges() == 0
     sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='in', replace=True)
@@ -884,7 +881,6 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
     assert fail
 @pytest.mark.parametrize('dtype', ['int32', 'int64'])
-@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
 def test_sample_neighbors_exclude_edges_heteroG(dtype):
     d_i_d_u_nodes = F.zerocopy_from_numpy(np.unique(np.random.randint(300, size=100, dtype=dtype)))
     d_i_d_v_nodes = F.zerocopy_from_numpy(np.random.randint(25, size=d_i_d_u_nodes.shape, dtype=dtype))
@@ -897,7 +893,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
         ('drug', 'interacts', 'drug'): (d_i_d_u_nodes, d_i_d_v_nodes),
         ('drug', 'interacts', 'gene'): (d_i_g_u_nodes, d_i_g_v_nodes),
         ('drug', 'treats', 'disease'): (d_t_d_u_nodes, d_t_d_v_nodes)
-    })
+    }).to(F.ctx())
     (U, V, EID) = (0, 1, 2)
@@ -950,11 +946,10 @@ etype=('drug','treats','disease'))))
                                    etype=('drug','treats','disease'))))
 @pytest.mark.parametrize('dtype', ['int32', 'int64'])
-@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
 def test_sample_neighbors_exclude_edges_homoG(dtype):
     u_nodes = F.zerocopy_from_numpy(np.unique(np.random.randint(300,size=100, dtype=dtype)))
     v_nodes = F.zerocopy_from_numpy(np.random.randint(25, size=u_nodes.shape, dtype=dtype))
-    g = dgl.graph((u_nodes, v_nodes))
+    g = dgl.graph((u_nodes, v_nodes)).to(F.ctx())
     (U, V, EID) = (0, 1, 2)
......
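With the skip decorators removed, the existing prob tests now exercise the GPU path as well. A hypothetical extra sanity check in the same spirit (PyTorch backend assumed, not part of the committed test suite) could assert that zero-weight edges are essentially never returned:

import dgl
import torch

def check_zero_prob_edges_not_sampled(device='cuda'):
    # Hypothetical sanity check, not part of this diff.
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([4, 4, 4, 4])
    g = dgl.graph((src, dst), num_nodes=5).to(device)
    g.edata['p'] = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
    for replace in (False, True):
        sg = dgl.sampling.sample_neighbors(
            g, torch.tensor([4], device=device), 2, prob='p', replace=replace)
        u, _ = sg.edges()
        assert set(u.tolist()) <= {0, 2}  # edges with weight 0 should not be picked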
-Subproject commit a3ee304a1f8e22f278df10600df2e4b333012592
+Subproject commit cdaa9558a85e45d849016e5fe7b6e4ee79113f95
-Subproject commit 0ef5c509856e12cc408f0f00ed586b4c5b1a155c
+Subproject commit 6a3078c64cab0e2f276340fa5dcafa0d758ed890