Commit 980dd4a2 (unverified), authored by CHU Tianxiang, committed by GitHub
Browse files

Fix overflow in awq kernel (#1295)


Co-authored-by: 楚天翔 <tianxiang.ctx@alibaba-inc.com>
parent 82857368
...@@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
+ (((int)threadIdx.x) % (128 / 8)) * 8; + (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128 + (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64 + ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2; + (((int)threadIdx.x) % 4) * 2;
...@@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in ...@@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
+ (((int)threadIdx.x) % (64 / 8)) * 8; + (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64 + (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32 + ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2; + (((int)threadIdx.x) % 4) * 2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment