Unverified Commit df69100c authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[Common] Fix long compile time in padding.cu on arch 75 (#2562)



* Fix long compile time in padding.cu
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a9767407
...@@ -94,7 +94,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP ...@@ -94,7 +94,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
#pragma unroll #pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) { for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2; const int row = tile_row + i1 * nvec + i2;
size_t row_offset = static_cast<size_t>(row) * row_length;
const int col = tile_col + j1 * nvec; const int col = tile_col + j1 * nvec;
Vec local_input; Vec local_input;
Vec local_output; Vec local_output;
...@@ -102,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP ...@@ -102,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if (row < num_rows) { if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) { for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row_offset + col + j2]; local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2];
} }
} }
} }
...@@ -113,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP ...@@ -113,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if (row < num_rows) { if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) { for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
output[row_offset + col + j2] = local_output.data.elt[j2]; output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2];
} }
} }
} else if (row < padded_num_rows) { } else if (row < padded_num_rows) {
// padding // padding
for (int j2 = 0; j2 < nvec; ++j2) { for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
output[row_offset + col + j2] = local_zero; output[static_cast<size_t>(row) * row_length + col + j2] = local_zero;
} }
} }
} }
...@@ -179,7 +178,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult ...@@ -179,7 +178,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
#pragma unroll #pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) { for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2; const int row = tile_row + i1 * nvec + i2;
size_t row_offset = static_cast<size_t>(row) * row_length;
const int col = tile_col + j1 * nvec; const int col = tile_col + j1 * nvec;
Vec local_input; Vec local_input;
Vec local_output; Vec local_output;
...@@ -187,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult ...@@ -187,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if (row < num_rows) { if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) { for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row_offset + col + j2]; local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2];
} }
} }
} }
...@@ -198,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult ...@@ -198,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if (row < num_rows) { if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) { for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
output[row_offset + col + j2] = local_output.data.elt[j2]; output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2];
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment