Unverified Commit 697b52cb authored by 刘俊's avatar 刘俊 Committed by GitHub
Browse files

Fix overflow of padding/unpadding kernel (#2548)


Signed-off-by: default avatarfuyue.lj <fuyue.lj@antgroup.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 26c82db6
......@@ -94,6 +94,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
size_t row_offset = static_cast<size_t>(row) * row_length;
const int col = tile_col + j1 * nvec;
Vec local_input;
Vec local_output;
......@@ -101,7 +102,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
local_input.data.elt[j2] = input[row_offset + col + j2];
}
}
}
......@@ -112,14 +113,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_output.data.elt[j2];
output[row_offset + col + j2] = local_output.data.elt[j2];
}
}
} else if (row < padded_num_rows) {
// padding
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_zero;
output[row_offset + col + j2] = local_zero;
}
}
}
......@@ -178,6 +179,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
size_t row_offset = static_cast<size_t>(row) * row_length;
const int col = tile_col + j1 * nvec;
Vec local_input;
Vec local_output;
......@@ -185,7 +187,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
local_input.data.elt[j2] = input[row_offset + col + j2];
}
}
}
......@@ -196,7 +198,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_output.data.elt[j2];
output[row_offset + col + j2] = local_output.data.elt[j2];
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment