Unverified Commit 8c32b08a authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] Fix awq error when n is not divisable by 128 (#13227)

parent 41088695
......@@ -334,7 +334,7 @@ __global__ void __launch_bounds__(64)
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int ax1_0_1 = 0; ax1_0_1 < (N / 32); ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment