Unverified Commit eec219ab authored by paoxiaode's avatar paoxiaode Committed by GitHub
Browse files

Change the parameter of curand_init (#3794)



* Change the curand_init parameter

* Change the curand_init parameter

* commit

* commit
Co-authored-by: default avatarnv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
parent 520cef88
......@@ -10,6 +10,7 @@ contributor to the DGL project. We will put your name in the list below.
Contributors
------------
* [Minjie Wang](https://github.com/jermainewang) from AWS
* [Da Zheng](https://github.com/zheng-da) from AWS
* [Quan Gan](https://github.com/BarclayII) from AWS
......@@ -60,3 +61,4 @@ Contributors
* [Xin Yao](https://github.com/yaox12) from Nvidia
* [Abdurrahman Yasar](https://github.com/ayasar70) from Nvidia
* [Shaked Brody](https://github.com/shakedbr) from Technion
* [Jiahui Liu](https://github.com/paoxiaode) from Nvidia
......@@ -132,7 +132,7 @@ __global__ void _CSRRowWiseSampleKernel(
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
curandState rng;
curand_init(rand_seed*gridDim.x+blockIdx.x, threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);
curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
......@@ -221,7 +221,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x+1)*TILE_SIZE, num_rows);
curandState rng;
curand_init(rand_seed*gridDim.x+blockIdx.x, threadIdx.y*WARP_SIZE+threadIdx.x, 0, &rng);
curand_init((rand_seed*gridDim.x+blockIdx.x)*blockDim.y+threadIdx.y, threadIdx.x, 0, &rng);
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment