Unverified Commit bacf2ab4 authored by paoxiaode's avatar paoxiaode Committed by GitHub
Browse files

change the curandState and launch dimension of CSRRowwiseSample kernel (#3990)



* Change the curand_init parameter

* Change the curand_init parameter

* commit

* commit

* change the curandState and launch dim of CSRRowwiseSample kernel

* commit

* keep  _CSRRowWiseSampleReplaceKernel in sync
Co-authored-by: default avatarnv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
parent e0e8736f
......@@ -21,7 +21,7 @@ namespace impl {
namespace {
constexpr int WARP_SIZE = 32;
constexpr int CTA_SIZE = 128;
/**
* @brief Compute the size of each row in the sampled CSR, without replacement.
......@@ -97,7 +97,7 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
* without replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
* @tparam BLOCK_CTAS The number of rows each thread block runs in parallel.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
......@@ -111,7 +111,7 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
template<typename IdType, int BLOCK_CTAS, int TILE_SIZE>
__global__ void _CSRRowWiseSampleKernel(
const uint64_t rand_seed,
const int64_t num_picks,
......@@ -125,13 +125,12 @@ __global__ void _CSRRowWiseSampleKernel(
IdType * const out_cols,
IdType * const out_idxs) {
// we assign one warp per row
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
curandState rng;
curandStatePhilox4_32_10_t rng;
curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
while (out_row < last_row) {
......@@ -144,7 +143,7 @@ __global__ void _CSRRowWiseSampleKernel(
if (deg <= num_picks) {
// just copy row
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
const IdType in_idx = in_row_start+idx;
out_rows[out_row_start+idx] = row;
out_cols[out_row_start+idx] = in_index[in_idx];
......@@ -152,12 +151,12 @@ __global__ void _CSRRowWiseSampleKernel(
}
} else {
// generate permutation list via reservoir algorithm
for (int idx = threadIdx.x; idx < num_picks; idx+=WARP_SIZE) {
for (int idx = threadIdx.x; idx < num_picks; idx+=CTA_SIZE) {
out_idxs[out_row_start+idx] = idx;
}
__syncwarp();
__syncthreads();
for (int idx = num_picks+threadIdx.x; idx < deg; idx+=WARP_SIZE) {
for (int idx = num_picks+threadIdx.x; idx < deg; idx+=CTA_SIZE) {
const int num = curand(&rng)%(idx+1);
if (num < num_picks) {
// use max so as to achieve the replacement order the serial
......@@ -165,10 +164,10 @@ __global__ void _CSRRowWiseSampleKernel(
AtomicMax(out_idxs+out_row_start+num, idx);
}
}
__syncwarp();
__syncthreads();
// copy permutation over
for (int idx = threadIdx.x; idx < num_picks; idx += WARP_SIZE) {
for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
const IdType perm_idx = out_idxs[out_row_start+idx]+in_row_start;
out_rows[out_row_start+idx] = row;
out_cols[out_row_start+idx] = in_index[perm_idx];
......@@ -178,7 +177,7 @@ __global__ void _CSRRowWiseSampleKernel(
}
}
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
}
}
......@@ -187,7 +186,7 @@ __global__ void _CSRRowWiseSampleKernel(
* with replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam BLOCK_WARPS The number of rows each thread block runs in parallel.
* @tparam BLOCK_CTAS The number of rows each thread block runs in parallel.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
......@@ -201,7 +200,7 @@ __global__ void _CSRRowWiseSampleKernel(
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int BLOCK_WARPS, int TILE_SIZE>
template<typename IdType, int BLOCK_CTAS, int TILE_SIZE>
__global__ void _CSRRowWiseSampleReplaceKernel(
const uint64_t rand_seed,
const int64_t num_picks,
......@@ -215,12 +214,12 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
IdType * const out_cols,
IdType * const out_idxs) {
// we assign one warp per row
assert(blockDim.x == WARP_SIZE);
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x*TILE_SIZE+threadIdx.y;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
curandState rng;
curandStatePhilox4_32_10_t rng;
curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
while (out_row < last_row) {
......@@ -233,7 +232,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
if (deg > 0) {
// each thread then blindly copies in rows only if deg > 0.
for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
for (int idx = threadIdx.x; idx < num_picks; idx += CTA_SIZE) {
const int64_t edge = curand(&rng) % deg;
const int64_t out_idx = out_row_start+idx;
out_rows[out_idx] = row;
......@@ -241,7 +240,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
}
}
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
}
}
......@@ -327,12 +326,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
// select edges
if (replace) {
constexpr int BLOCK_WARPS = 128/WARP_SIZE;
constexpr int BLOCK_CTAS = 128/CTA_SIZE;
// the number of rows each thread block will cover
constexpr int TILE_SIZE = BLOCK_WARPS*16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
constexpr int TILE_SIZE = BLOCK_CTAS;
const dim3 block(CTA_SIZE, BLOCK_CTAS);
const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
_CSRRowWiseSampleReplaceKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
_CSRRowWiseSampleReplaceKernel<IdType, BLOCK_CTAS, TILE_SIZE><<<grid, block, 0, stream>>>(
random_seed,
num_picks,
num_rows,
......@@ -345,12 +344,12 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
out_cols,
out_idxs);
} else {
constexpr int BLOCK_WARPS = 128/WARP_SIZE;
constexpr int BLOCK_CTAS = 128/CTA_SIZE;
// the number of rows each thread block will cover
constexpr int TILE_SIZE = BLOCK_WARPS*16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
constexpr int TILE_SIZE = BLOCK_CTAS;
const dim3 block(CTA_SIZE, BLOCK_CTAS);
const dim3 grid((num_rows+TILE_SIZE-1)/TILE_SIZE);
_CSRRowWiseSampleKernel<IdType, BLOCK_WARPS, TILE_SIZE><<<grid, block, 0, stream>>>(
_CSRRowWiseSampleKernel<IdType, BLOCK_CTAS, TILE_SIZE><<<grid, block, 0, stream>>>(
random_seed,
num_picks,
num_rows,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment