Unverified Commit bacf2ab4 authored by paoxiaode, committed by GitHub

change the curandState and launch dimension of CSRRowwiseSample kernel (#3990)



* Change the curand_init parameter

* Change the curand_init parameter

* commit

* commit

* change the curandState and launch dim of CSRRowwiseSample kernel

* commit

* keep _CSRRowWiseSampleReplaceKernel in sync

Co-authored-by: nv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
parent e0e8736f
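For readers skimming the diff below: both sampling kernels switch their per-thread generator from the default XORWOW `curandState` to `curandStatePhilox4_32_10_t`, which is considerably cheaper to initialize, and they are now launched as one 128-thread CTA per row instead of one 32-thread warp per row. A minimal, illustrative sketch of the new per-thread RNG setup (the seeding expression is copied from the diff; the standalone kernel and its output buffer are hypothetical, not part of DGL):

```cuda
#include <cstdint>
#include <curand_kernel.h>

// Illustrative only: the per-thread RNG setup the kernels in this commit
// switch to. curandStatePhilox4_32_10_t has a much cheaper curand_init()
// than the default XORWOW curandState, which matters because every thread
// re-creates its state on every kernel launch.
__global__ void RngInitSketch(const uint64_t rand_seed, unsigned int* out) {
  curandStatePhilox4_32_10_t rng;
  // Same seeding scheme as the diff: one seed per (block, row slot), with
  // threadIdx.x selecting the subsequence so threads draw independent streams.
  curand_init((rand_seed * gridDim.x + blockIdx.x) * blockDim.y + threadIdx.y,
              threadIdx.x, 0, &rng);
  const int tid =
      (blockIdx.x * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  out[tid] = curand(&rng);  // e.g. later reduced modulo the row degree
}
```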
@@ -21,7 +21,7 @@ namespace impl {
 namespace {
-constexpr int WARP_SIZE = 32;
+constexpr int CTA_SIZE = 128;
 /**
  * @brief Compute the size of each row in the sampled CSR, without replacement.
@@ -97,7 +97,7 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
  * without replacement.
  *
  * @tparam IdType The ID type used for matrices.
- * @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
+ * @tparam BLOCK_CTAS The number of rows each thread block runs in parallel.
  * @tparam TILE_SIZE The number of rows covered by each threadblock.
  * @param rand_seed The random seed to use.
  * @param num_picks The number of non-zeros to pick per row.
@@ -111,7 +111,7 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
  * @param out_cols The columns of the output COO (output).
  * @param out_idxs The data array of the output COO (output).
  */
-template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
+template<typename IdType, int BLOCK_CTAS, int TILE_SIZE>
 __global__ void _CSRRowWiseSampleKernel(
     const uint64_t rand_seed,
     const int64_t num_picks,
@@ -125,13 +125,12 @@ __global__ void _CSRRowWiseSampleKernel(
     IdType * const out_cols,
     IdType * const out_idxs) {
   // we assign one warp per row
-  assert(blockDim.x == WARP_SIZE);
-  assert(blockDim.y == BLOCK_WARPS);
+  assert(blockDim.x == CTA_SIZE);
   int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
   const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
-  curandState rng;
+  curandStatePhilox4_32_10_t rng;
   curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
   while (out_row < last_row) {
@@ -144,7 +143,7 @@ __global__ void _CSRRowWiseSampleKernel(
     if (deg <= num_picks) {
       // just copy row
-      for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
         const IdType in_idx = in_row_start+idx;
         out_rows[out_row_start+idx] = row;
         out_cols[out_row_start+idx] = in_index[in_idx];
@@ -152,12 +151,12 @@ __global__ void _CSRRowWiseSampleKernel(
       }
     } else {
       // generate permutation list via reservoir algorithm
-      for (int idx = threadIdx.x; idx < num_picks; idx+=WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < num_picks; idx+=CTA_SIZE) {
         out_idxs[out_row_start+idx] = idx;
       }
-      __syncwarp();
-      for (int idx = num_picks+threadIdx.x; idx < deg; idx+=WARP_SIZE) {
+      __syncthreads();
+      for (int idx = num_picks+threadIdx.x; idx < deg; idx+=CTA_SIZE) {
         const int num = curand(&rng)%(idx+1);
         if (num < num_picks) {
           // use max so as to achieve the replacement order the serial
@@ -165,10 +164,10 @@ __global__ void _CSRRowWiseSampleKernel(
           AtomicMax(out_idxs+out_row_start+num, idx);
         }
       }
-      __syncwarp();
+      __syncthreads();
       // copy permutation over
-      for (int idx = threadIdx.x; idx < num_picks; idx += WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
         const IdType perm_idx = out_idxs[out_row_start+idx]+in_row_start;
         out_rows[out_row_start+idx] = row;
         out_cols[out_row_start+idx] = in_index[perm_idx];
@@ -178,7 +177,7 @@ __global__ void _CSRRowWiseSampleKernel(
       }
     }
-    out_row += BLOCK_WARPS;
+    out_row += BLOCK_CTAS;
   }
 }
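Because the loops above now stride by CTA_SIZE instead of WARP_SIZE, the warp-level __syncwarp() barriers become block-wide __syncthreads() barriers. Below is a condensed, single-row sketch of that selection loop, for illustration only: it uses CUDA's built-in atomicMax on plain int rather than DGL's AtomicMax helper, drops the per-row offsets, and the kernel name ReservoirPickSketch is hypothetical.

```cuda
#include <cstdint>
#include <curand_kernel.h>

constexpr int CTA_SIZE = 128;  // matches the constant introduced in the diff

// One CTA picks num_picks distinct indices out of [0, deg), deg > num_picks.
// out_idxs must hold at least num_picks ints.
__global__ void ReservoirPickSketch(const uint64_t rand_seed,
                                    const int num_picks,
                                    const int deg,
                                    int* const out_idxs) {
  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  // Seed the reservoir with the first num_picks positions.
  for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
    out_idxs[idx] = idx;
  }
  __syncthreads();

  // Each later position idx lands in a random slot with probability
  // num_picks/(idx+1); atomicMax keeps the replacement order of the serial
  // algorithm when several threads hit the same slot.
  for (int idx = num_picks + threadIdx.x; idx < deg; idx += CTA_SIZE) {
    const int num = curand(&rng) % (idx + 1);
    if (num < num_picks) {
      atomicMax(out_idxs + num, idx);
    }
  }
  __syncthreads();
}
```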
@@ -187,7 +186,7 @@ __global__ void _CSRRowWiseSampleKernel(
  * with replacement.
  *
  * @tparam IdType The ID type used for matrices.
- * @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
+ * @tparam BLOCK_CTAS The number of rows each thread block runs in parallel.
  * @tparam TILE_SIZE The number of rows covered by each threadblock.
  * @param rand_seed The random seed to use.
  * @param num_picks The number of non-zeros to pick per row.
@@ -201,7 +200,7 @@ __global__ void _CSRRowWiseSampleKernel(
  * @param out_cols The columns of the output COO (output).
  * @param out_idxs The data array of the output COO (output).
  */
-template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
+template<typename IdType, int BLOCK_CTAS, int TILE_SIZE>
 __global__ void _CSRRowWiseSampleReplaceKernel(
     const uint64_t rand_seed,
     const int64_t num_picks,
@@ -215,12 +214,12 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
     IdType * const out_cols,
     IdType * const out_idxs) {
   // we assign one warp per row
-  assert(blockDim.x == WARP_SIZE);
+  assert(blockDim.x == CTA_SIZE);
   int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
   const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
-  curandState rng;
+  curandStatePhilox4_32_10_t rng;
   curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
   while (out_row < last_row) {
@@ -233,7 +232,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
     if (deg > 0) {
       // each thread then blindly copies in rows only if deg > 0.
-      for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
+      for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
         const int64_t edge = curand(&rng) % deg;
         const int64_t out_idx = out_row_start+idx;
         out_rows[out_idx] = row;
@@ -241,7 +240,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
         out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
       }
     }
-    out_row += BLOCK_WARPS;
+    out_row += BLOCK_CTAS;
   }
 }
@@ -327,12 +326,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
   // select edges
   if (replace) {
-    constexpr int BLOCK_WARPS = 128/WARP_SIZE;
+    constexpr int BLOCK_CTAS = 128/CTA_SIZE;
     // the number of rows each thread block will cover
-    constexpr int TILE_SIZE = BLOCK_WARPS*16;
-    const dim3 block(WARP_SIZE, BLOCK_WARPS);
+    constexpr int TILE_SIZE = BLOCK_CTAS;
+    const dim3 block(CTA_SIZE, BLOCK_CTAS);
     const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
-    _CSRRowWiseSampleReplaceKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
+    _CSRRowWiseSampleReplaceKernel<IdType, BLOCK_CTAS, TILE_SIZE><<<grid, block, 0, stream>>>(
         random_seed,
         num_picks,
         num_rows,
@@ -345,12 +344,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
         out_cols,
         out_idxs);
   } else {
-    constexpr int BLOCK_WARPS = 128/WARP_SIZE;
+    constexpr int BLOCK_CTAS = 128/CTA_SIZE;
     // the number of rows each thread block will cover
-    constexpr int TILE_SIZE = BLOCK_WARPS*16;
-    const dim3 block(WARP_SIZE, BLOCK_WARPS);
+    constexpr int TILE_SIZE = BLOCK_CTAS;
+    const dim3 block(CTA_SIZE, BLOCK_CTAS);
     const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
-    _CSRRowWiseSampleKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
+    _CSRRowWiseSampleKernel<IdType, BLOCK_CTAS, TILE_SIZE><<<grid, block, 0, stream>>>(
         random_seed,
         num_picks,
         num_rows,
...
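With the constants above, the launch geometry changes from four 32-thread warps per block covering 64 rows (TILE_SIZE = 4*16) to a single 128-thread CTA per block covering exactly one row (TILE_SIZE = 1), i.e. one block per row. A small host-side sketch that just evaluates the two configurations; the constants follow the diff, while num_rows is a made-up example value:

```cuda
#include <cstdint>
#include <cstdio>

// Launch geometry before and after this commit (illustrative, not DGL code).
int main() {
  const int64_t num_rows = 1000000;  // hypothetical number of rows to sample

  // Before: one 32-thread warp per row, 4 warps per block, 64 rows per block.
  constexpr int WARP_SIZE = 32;
  constexpr int OLD_BLOCK_WARPS = 128 / WARP_SIZE;     // 4
  constexpr int OLD_TILE_SIZE = OLD_BLOCK_WARPS * 16;  // 64
  const int64_t old_grid = (num_rows + OLD_TILE_SIZE - 1) / OLD_TILE_SIZE;

  // After: one 128-thread CTA per row, so TILE_SIZE collapses to 1 and the
  // grid launches one block per row.
  constexpr int CTA_SIZE = 128;
  constexpr int BLOCK_CTAS = 128 / CTA_SIZE;           // 1
  constexpr int TILE_SIZE = BLOCK_CTAS;                // 1
  const int64_t new_grid = (num_rows + TILE_SIZE - 1) / TILE_SIZE;

  std::printf("old: block(%d, %d), grid %lld\n",
              WARP_SIZE, OLD_BLOCK_WARPS, static_cast<long long>(old_grid));
  std::printf("new: block(%d, %d), grid %lld\n",
              CTA_SIZE, BLOCK_CTAS, static_cast<long long>(new_grid));
  return 0;
}
```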