Unverified commit f7ce2671 authored by nv-dlasalle, committed by GitHub

[bugfix] Fix curand_init() calls in rowwise sampling (#3196)



* Split out separate generators for each thread

* Amortize cost of curand_init

* Improve readability
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent c40bbf4f
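
Before the full diff, here is a minimal, self-contained sketch (not the DGL kernel itself; the kernel name and output buffer are invented purely for illustration) of the per-thread RNG pattern the patch moves to: each thread keeps its own curandState in registers instead of sharing states through shared memory, calls curand_init() once with a seed/subsequence derived from its block and thread coordinates, and then reuses that state for every row in its block's tile of TILE_SIZE rows, which is how the per-row cost of curand_init() is amortized.

#include <cstdint>
#include <curand_kernel.h>

constexpr int WARP_SIZE = 32;

// Hypothetical kernel: assumes blockDim.x == WARP_SIZE and blockDim.y == BLOCK_WARPS,
// with each block covering TILE_SIZE consecutive rows, as in the patch.
template<int BLOCK_WARPS, int TILE_SIZE>
__global__ void _PerThreadRngSketch(
    const uint64_t rand_seed,
    const int64_t num_rows,
    unsigned int * const out) {  // illustrative output: one draw per (row, lane)
  // One private state per thread; a distinct subsequence per (block, thread)
  // keeps the streams independent without shared memory or __syncthreads().
  curandState rng;
  curand_init(rand_seed*gridDim.x+blockIdx.x,
              threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);

  // Each warp (threadIdx.y) strides through its block's tile of rows.
  int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
  while (out_row < last_row) {
    // The already-initialized state is reused for every draw in the tile.
    out[out_row*WARP_SIZE+threadIdx.x] = curand(&rng);
    out_row += BLOCK_WARPS;
  }
}
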
@@ -97,7 +97,8 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
  * without replacement.
  *
  * @tparam IdType The ID type used for matrices.
- * @tparam BLOCK_ROWS The number of rows covered by each threadblock.
+ * @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
+ * @tparam TILE_SIZE The number of rows covered by each threadblock.
  * @param rand_seed The random seed to use.
  * @param num_picks The number of non-zeros to pick per row.
  * @param num_rows The number of rows to pick.
@@ -110,7 +111,7 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
  * @param out_cols The columns of the output COO (output).
  * @param out_idxs The data array of the output COO (output).
  */
-template<typename IdType, int BLOCK_ROWS>
+template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
 __global__ void _CSRRowWiseSampleKernel(
     const uint64_t rand_seed,
     const int64_t num_picks,
@@ -125,20 +126,15 @@ __global__ void _CSRRowWiseSampleKernel(
     IdType * const out_idxs) {
   // we assign one warp per row
   assert(blockDim.x == WARP_SIZE);
-  assert(blockDim.y == BLOCK_ROWS);
+  assert(blockDim.y == BLOCK_WARPS);
 
-  // we need one state per 256 threads
-  constexpr int NUM_RNG = ((WARP_SIZE*BLOCK_ROWS)+255)/256;
-  __shared__ curandState rng_array[NUM_RNG];
-  assert(blockDim.x >= NUM_RNG);
-  if (threadIdx.y == 0 && threadIdx.x < NUM_RNG) {
-    curand_init(rand_seed, 0, threadIdx.x, rng_array+threadIdx.x);
-  }
-  __syncthreads();
-  curandState * const rng = rng_array+((threadIdx.x+WARP_SIZE*threadIdx.y)/256);
+  int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
+  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
 
-  int64_t out_row = blockIdx.x*BLOCK_ROWS+threadIdx.y;
-  while (out_row < num_rows) {
+  curandState rng;
+  curand_init(rand_seed*gridDim.x+blockIdx.x, threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);
+
+  while (out_row < last_row) {
     const int64_t row = in_rows[out_row];
 
     const int64_t in_row_start = in_ptr[row];
@@ -162,7 +158,7 @@ __global__ void _CSRRowWiseSampleKernel(
     __syncwarp();
 
     for (int idx = num_picks+threadIdx.x; idx < deg; idx+=WARP_SIZE) {
-      const int num = curand(rng)%(idx+1);
+      const int num = curand(&rng)%(idx+1);
       if (num < num_picks) {
         // use max so as to achieve the replacement order the serial
         // algorithm would have
@@ -182,7 +178,7 @@ __global__ void _CSRRowWiseSampleKernel(
       }
     }
 
-    out_row += gridDim.x*BLOCK_ROWS;
+    out_row += BLOCK_WARPS;
   }
 }
@@ -191,7 +187,8 @@ __global__ void _CSRRowWiseSampleKernel(
  * with replacement.
  *
  * @tparam IdType The ID type used for matrices.
- * @tparam BLOCK_ROWS The number of rows covered by each threadblock.
+ * @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
+ * @tparam TILE_SIZE The number of rows covered by each threadblock.
  * @param rand_seed The random seed to use.
  * @param num_picks The number of non-zeros to pick per row.
  * @param num_rows The number of rows to pick.
@@ -204,7 +201,7 @@ __global__ void _CSRRowWiseSampleKernel(
  * @param out_cols The columns of the output COO (output).
  * @param out_idxs The data array of the output COO (output).
  */
-template<typename IdType, int BLOCK_ROWS>
+template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
 __global__ void _CSRRowWiseSampleReplaceKernel(
     const uint64_t rand_seed,
     const int64_t num_picks,
@@ -220,18 +217,13 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
   // we assign one warp per row
   assert(blockDim.x == WARP_SIZE);
 
-  // we need one state per 256 threads
-  constexpr int NUM_RNG = ((WARP_SIZE*BLOCK_ROWS)+255)/256;
-  __shared__ curandState rng_array[NUM_RNG];
-  assert(blockDim.x >= NUM_RNG);
-  if (threadIdx.y == 0 && threadIdx.x < NUM_RNG) {
-    curand_init(rand_seed, 0, threadIdx.x, rng_array+threadIdx.x);
-  }
-  __syncthreads();
-  curandState * const rng = rng_array+((threadIdx.x+WARP_SIZE*threadIdx.y)/256);
+  int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
+  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
 
-  int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
-  while (out_row < num_rows) {
+  curandState rng;
+  curand_init(rand_seed*gridDim.x+blockIdx.x, threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);
+
+  while (out_row < last_row) {
     const int64_t row = in_rows[out_row];
 
     const int64_t in_row_start = in_ptr[row];
@@ -239,15 +231,17 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
 
     const int64_t deg = in_ptr[row+1] - in_row_start;
 
-    // each thread then blindly copies in rows
-    for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
-      const int64_t edge = curand(rng) % deg;
-      const int64_t out_idx = out_row_start+idx;
-      out_rows[out_idx] = row;
-      out_cols[out_idx] = in_index[in_row_start+edge];
-      out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
+    if (deg > 0) {
+      // each thread then blindly copies in rows only if deg > 0.
+      for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
+        const int64_t edge = curand(&rng) % deg;
+        const int64_t out_idx = out_row_start+idx;
+        out_rows[out_idx] = row;
+        out_cols[out_idx] = in_index[in_row_start+edge];
+        out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
+      }
     }
 
-    out_row += gridDim.x*blockDim.y;
+    out_row += BLOCK_WARPS;
   }
 }
@@ -333,10 +327,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   // select edges
   if (replace) {
-    constexpr int BLOCK_ROWS = 128/WARP_SIZE;
-    const dim3 block(WARP_SIZE, BLOCK_ROWS);
-    const dim3 grid((num_rows+block.y-1)/block.y);
-    _CSRRowWiseSampleReplaceKernel<IdType, BLOCK_ROWS><<<grid, block, 0, stream>>>(
+    constexpr int BLOCK_WARPS = 128/WARP_SIZE;
+    // the number of rows each thread block will cover
+    constexpr int TILE_SIZE = BLOCK_WARPS*16;
+    const dim3 block(WARP_SIZE, BLOCK_WARPS);
+    const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
+    _CSRRowWiseSampleReplaceKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
         random_seed,
         num_picks,
         num_rows,
@@ -349,10 +345,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
         out_cols,
         out_idxs);
   } else {
-    constexpr int BLOCK_ROWS = 128/WARP_SIZE;
-    const dim3 block(WARP_SIZE, BLOCK_ROWS);
-    const dim3 grid((num_rows+block.y-1)/block.y);
-    _CSRRowWiseSampleKernel<IdType, BLOCK_ROWS><<<grid, block, 0, stream>>>(
+    constexpr int BLOCK_WARPS = 128/WARP_SIZE;
+    // the number of rows each thread block will cover
+    constexpr int TILE_SIZE = BLOCK_WARPS*16;
+    const dim3 block(WARP_SIZE, BLOCK_WARPS);
+    const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
+    _CSRRowWiseSampleKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
         random_seed,
         num_picks,
         num_rows,
...
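
As a rough worked example of the new launch configuration above (assuming WARP_SIZE is 32): BLOCK_WARPS = 128/32 = 4 and TILE_SIZE = 4*16 = 64, so each 128-thread block covers 64 rows and each of its 4 warps processes 16 rows, with the grid holding ceil(num_rows/64) blocks. A small host-side sketch with an illustrative row count:

// Hedged sketch of the launch math; num_rows is an example value only.
#include <cstdint>
#include <cuda_runtime.h>

constexpr int WARP_SIZE = 32;
constexpr int BLOCK_WARPS = 128/WARP_SIZE;          // 4 warps, i.e. 4 rows in flight per block
constexpr int TILE_SIZE = BLOCK_WARPS*16;           // each block covers 64 rows
const int64_t num_rows = 1000;                      // illustrative row count
const dim3 block(WARP_SIZE, BLOCK_WARPS);           // 32 x 4 = 128 threads per block
const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);  // ceil(1000/64) = 16 blocks
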