Unverified commit acf98b5c, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Fix the out-of-bounds access in the C+T+dbias kernel (#28)


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent 85e0373f
...@@ -121,7 +121,7 @@ cast_transpose_dbias_kernel(const Param param, ...@@ -121,7 +121,7 @@ cast_transpose_dbias_kernel(const Param param,
extern __shared__ char scratch[]; extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP); // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
...@@ -262,7 +262,7 @@ cast_transpose_dbias_kernel_notaligned(const Param param, ...@@ -262,7 +262,7 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
extern __shared__ char scratch[]; extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP); (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
...@@ -399,7 +399,9 @@ cast_transpose_dbias_kernel_notaligned(const Param param, ...@@ -399,7 +399,9 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
} }
} }
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
} }
/* warp tile amax reduce*/ /* warp tile amax reduce*/
...@@ -630,7 +632,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param, ...@@ -630,7 +632,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
extern __shared__ char scratch[]; extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP); // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
...@@ -791,7 +793,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param, ...@@ -791,7 +793,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
extern __shared__ char scratch[]; extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP); (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
...@@ -948,7 +950,9 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param, ...@@ -948,7 +950,9 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
} }
} }
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
} }
/* warp tile amax reduce*/ /* warp tile amax reduce*/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment