"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fac761694ab084256fb21dc8bf3e281b6bd8e8d8"
Unverified commit f7ce2671, authored by nv-dlasalle, committed by GitHub
Browse files

[bugfix] Fix curand_init() calls in rowwise sampling (#3196)



* Split out separate generators for each thread

* Amortize cost of curand_init

* Improve readability
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent c40bbf4f
......@@ -97,7 +97,8 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
* without replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam BLOCK_ROWS The number of rows covered by each threadblock.
* @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
......@@ -110,7 +111,7 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int BLOCK_ROWS>
template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
__global__ void _CSRRowWiseSampleKernel(
const uint64_t rand_seed,
const int64_t num_picks,
......@@ -125,20 +126,15 @@ __global__ void _CSRRowWiseSampleKernel(
IdType * const out_idxs) {
// we assign one warp per row
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_ROWS);
// we need one state per 256 threads
constexpr int NUM_RNG = ((WARP_SIZE*BLOCK_ROWS)+255)/256;
__shared__ curandState rng_array[NUM_RNG];
assert(blockDim.x >= NUM_RNG);
if (threadIdx.y == 0 && threadIdx.x < NUM_RNG) {
curand_init(rand_seed, 0, threadIdx.x, rng_array+threadIdx.x);
}
__syncthreads();
curandState * const rng = rng_array+((threadIdx.x+WARP_SIZE*threadIdx.y)/256);
assert(blockDim.y == BLOCK_WARPS);
int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
curandState rng;
curand_init(rand_seed*gridDim.x+blockIdx.x, threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);
int64_t out_row = blockIdx.x*BLOCK_ROWS+threadIdx.y;
while (out_row < num_rows) {
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
......@@ -162,7 +158,7 @@ __global__ void _CSRRowWiseSampleKernel(
__syncwarp();
for (int idx = num_picks+threadIdx.x; idx < deg; idx+=WARP_SIZE) {
const int num = curand(rng)%(idx+1);
const int num = curand(&rng)%(idx+1);
if (num < num_picks) {
// use max so as to achieve the replacement order the serial
// algorithm would have
......@@ -182,7 +178,7 @@ __global__ void _CSRRowWiseSampleKernel(
}
}
out_row += gridDim.x*BLOCK_ROWS;
out_row += BLOCK_WARPS;
}
}
......@@ -191,7 +187,8 @@ __global__ void _CSRRowWiseSampleKernel(
* with replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam BLOCK_ROWS The number of rows covered by each threadblock.
* @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
......@@ -204,7 +201,7 @@ __global__ void _CSRRowWiseSampleKernel(
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int BLOCK_ROWS>
template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
__global__ void _CSRRowWiseSampleReplaceKernel(
const uint64_t rand_seed,
const int64_t num_picks,
......@@ -220,18 +217,13 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
// we assign one warp per row
assert(blockDim.x == WARP_SIZE);
// we need one state per 256 threads
constexpr int NUM_RNG = ((WARP_SIZE*BLOCK_ROWS)+255)/256;
__shared__ curandState rng_array[NUM_RNG];
assert(blockDim.x >= NUM_RNG);
if (threadIdx.y == 0 && threadIdx.x < NUM_RNG) {
curand_init(rand_seed, 0, threadIdx.x, rng_array+threadIdx.x);
}
__syncthreads();
curandState * const rng = rng_array+((threadIdx.x+WARP_SIZE*threadIdx.y)/256);
int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
curandState rng;
curand_init(rand_seed*gridDim.x+blockIdx.x, threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);
int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
while (out_row < num_rows) {
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
......@@ -239,15 +231,17 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
const int64_t deg = in_ptr[row+1] - in_row_start;
// each thread then blindly copies in rows
for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
const int64_t edge = curand(rng) % deg;
const int64_t out_idx = out_row_start+idx;
out_rows[out_idx] = row;
out_cols[out_idx] = in_index[in_row_start+edge];
out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
if (deg > 0) {
// each thread then blindly copies in rows only if deg > 0.
for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
const int64_t edge = curand(&rng) % deg;
const int64_t out_idx = out_row_start+idx;
out_rows[out_idx] = row;
out_cols[out_idx] = in_index[in_row_start+edge];
out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
}
}
out_row += gridDim.x*blockDim.y;
out_row += BLOCK_WARPS;
}
}
......@@ -333,10 +327,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
// select edges
if (replace) {
constexpr int BLOCK_ROWS = 128/WARP_SIZE;
const dim3 block(WARP_SIZE, BLOCK_ROWS);
const dim3 grid((num_rows+block.y-1)/block.y);
_CSRRowWiseSampleReplaceKernel<IdType, BLOCK_ROWS><<<grid, block, 0, stream>>>(
constexpr int BLOCK_WARPS = 128/WARP_SIZE;
// the number of rows each thread block will cover
constexpr int TILE_SIZE = BLOCK_WARPS*16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
_CSRRowWiseSampleReplaceKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
random_seed,
num_picks,
num_rows,
......@@ -349,10 +345,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
out_cols,
out_idxs);
} else {
constexpr int BLOCK_ROWS = 128/WARP_SIZE;
const dim3 block(WARP_SIZE, BLOCK_ROWS);
const dim3 grid((num_rows+block.y-1)/block.y);
_CSRRowWiseSampleKernel<IdType, BLOCK_ROWS><<<grid, block, 0, stream>>>(
constexpr int BLOCK_WARPS = 128/WARP_SIZE;
// the number of rows each thread block will cover
constexpr int TILE_SIZE = BLOCK_WARPS*16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
_CSRRowWiseSampleKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
random_seed,
num_picks,
num_rows,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment