Commit 0cadd65f (unverified), authored by Casper, committed by GitHub
Browse files

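Fix potential overflow (#102)

Each hunk below changes the first term of the `C_ptr` offset from `blockIdx_z * M * OC` to `static_cast<long long>(blockIdx_z) * M * OC`, so that the product is evaluated in `long long` arithmetic. The diff's +/- markers were lost on this page: in each hunk the old line appears directly above its replacement.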

parent 18712d00
......@@ -86,7 +86,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
......@@ -314,7 +314,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
......@@ -561,7 +561,7 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
// Haotian: TBD, check, May 29 11:46 AM PST
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ (((int)threadIdx.y) / 2) * 32
+ (((int)threadIdx.x) % 4) * 2;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment