"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "60b3b5a2ad8e69f4ccb81d661a36d6f80cc63ed3"
Unverified Commit 0cadd65f authored by Casper's avatar Casper Committed by GitHub
Browse files

Fix potential overflow (#102)

parent 18712d00
...@@ -86,7 +86,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i ...@@ -86,7 +86,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
+ (((int)threadIdx.x) % (128 / 8)) * 8; + (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128 + (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64 + ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2; + (((int)threadIdx.x) % 4) * 2;
...@@ -314,7 +314,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in ...@@ -314,7 +314,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
+ (((int)threadIdx.x) % (64 / 8)) * 8; + (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64 + (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32 + ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2; + (((int)threadIdx.x) % 4) * 2;
...@@ -561,7 +561,7 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s ...@@ -561,7 +561,7 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
// Haotian: TBD, check, May 29 11:46 AM PST // Haotian: TBD, check, May 29 11:46 AM PST
half* C_ptr = C half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z -> split_k dim + static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64 + (((int)blockIdx_y) % j_factors1) * 64
+ (((int)threadIdx.y) / 2) * 32 + (((int)threadIdx.y) / 2) * 32
+ (((int)threadIdx.x) % 4) * 2; + (((int)threadIdx.x) % 4) * 2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment