Commit 2a51f5e9 authored by ys-2020

[Minor] fixed CUDA kernel launching bug

parent 3a6dfc39
@@ -25,6 +25,10 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   int j_factors1 = ((OC + 128 - 1) / 128);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
   half A_shared_warp[8];
   half B_shared_warp[32];
   for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
@@ -36,20 +40,19 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   static constexpr int row_stride_warp = 32 * 8 / 32;
   static constexpr int row_stride = 2 * 32 * 8 / 128;
   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
-  // TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
-  bool ld_A_flag = (blockIdx.y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
   // bool wb_C_flag = (threadIdx.x / 4) < M;
   half* A_ptr = A
-              + (((int)blockIdx.y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+              + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
               + (((int)threadIdx.x) % (32 / 8)) * 8;
   int* B_ptr = B
             + ((int)threadIdx.y) * (OC / 8) * 2
             + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
-            + (((int)blockIdx.y) % j_factors1) * (128 / 8)
+            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
             + (((int)threadIdx.x) % (128 / 8)) * 1;
-  // Why * 1 in the above line?
   half* A_shared_ptr = A_shared
                      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
@@ -62,26 +65,26 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
                      + (((int)threadIdx.x) % (128 / 8)) * 8;
   int* zeros_ptr = zeros
-             + (((int)blockIdx.y) % j_factors1) * (128 / 8)
+             + (((int)blockIdx_y) % j_factors1) * (128 / 8)
              + ((int)threadIdx.x) % (128 / 8);
   half* scaling_factors_ptr = scaling_factors
-                            + (((int)blockIdx.y) % j_factors1) * (128)
+                            + (((int)blockIdx_y) % j_factors1) * (128)
                             + (((int)threadIdx.x) % (128 / 8)) * 8;
   half* C_ptr = C
-             + blockIdx.z * M * OC        // blockIdz.x -> split_k dim
-             + (((int)blockIdx.y) % j_factors1) * 128
+             + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+             + (((int)blockIdx_y) % j_factors1) * 128
              + ((int)threadIdx.y) * 64
              + (((int)threadIdx.x) % 4) * 2;
   // preload s.f. and zeros
   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * 32 + blockIdx.z >= IC) k_bound -= 1;
+  if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-    int k_0_0 = _k_0_0 * split_k_iters + blockIdx.z;
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
     __syncthreads();
-    // TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
     if (ld_A_flag)
     {
       *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
@@ -96,7 +99,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
       uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
       uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
       /*
-      if (blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
        printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
       }
       */
@@ -104,12 +107,11 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
       int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
       for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
-        // TODO: Shang: double check how to get 8.
         // B: 32 x 136 (128+8) float16
         // each warp: 32 x 4
         // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
-        // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx.y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+        // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
         // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
         uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
         uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
@@ -127,7 +129,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
         asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
         asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
         /*
-        if (ax0_ax1_fused_0 == 0 && blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
          printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
         }
         */
@@ -194,7 +196,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   // TODO: Shang: Hoist loop invariance.
   for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
     for (int local_id = 0; local_id < 8; ++local_id) {
-      int row_offset = (((int)blockIdx.y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
       if (row_offset < M)
       {
         *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
@@ -231,15 +233,13 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
-    // blockIdx.x: i_factors[0] * j_factors[0]
-    // blockIdx.y: i_factors[1] * j_factors[1]
     if (num_out_channels % 128 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 128");
     if (num_out_channels % 8 != 0)
         throw std::invalid_argument("OC is not multiple of pack_num = 8");
     int j_factors1 = num_out_channels / 128 / 1;
-    dim3 num_blocks(1, (num_out_feats + 16 - 1) / 16 * j_factors1, split_k_iters);
+    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
...
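
For context, a minimal standalone sketch (not part of the commit; the M, OC, and split_k_iters values below are made up) of what the launch change does: the old 3-D grid dim3(1, y_blocks, split_k_iters) becomes a flat 1-D grid of y_blocks * split_k_iters blocks, and the kernel recovers the former blockIdx.y / blockIdx.z with a modulo and a division. Flattening into gridDim.x keeps the block count away from the 65,535 cap CUDA places on the y and z grid dimensions, which is presumably the launching bug being fixed.

// Illustrative only; names and shapes are assumptions, not repository code.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void index_demo(int y_blocks)  // y_blocks = (M + 16 - 1) / 16 * j_factors1
{
    // Same recovery as the patched kernel: y indexes the 16x128 output tiles,
    // z indexes the split-k slice.
    int blockIdx_y = blockIdx.x % y_blocks;
    int blockIdx_z = blockIdx.x / y_blocks;
    if (threadIdx.x == 0)
        printf("block %d -> (y=%d, z=%d)\n", (int)blockIdx.x, blockIdx_y, blockIdx_z);
}

int main()
{
    // Hypothetical shapes: M = 16 output rows, OC = 256 output channels, 4 split-k slices.
    int M = 16, OC = 256, split_k_iters = 4;
    int j_factors1 = OC / 128;                      // 128-wide column tiles
    int y_blocks = (M + 16 - 1) / 16 * j_factors1;  // 16-tall row tiles * column tiles

    // Old launch: dim3 num_blocks(1, y_blocks, split_k_iters);
    // New launch: a single flattened dimension of y_blocks * split_k_iters blocks.
    index_demo<<<y_blocks * split_k_iters, 32>>>(y_blocks);
    cudaDeviceSynchronize();
    return 0;
}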
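
The inner-loop comment above ("read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero") describes eight 4-bit weights packed into one 32-bit word. Below is a host-side reference sketch of that arithmetic only; it assumes a plain low-to-high nibble order, whereas the kernel's dequantize_s4_to_fp16x2() works on its own packed fp16x2 layout, which may be interleaved.

// Reference sketch; the real kernel performs this with sub.f16x2 / fma.rn.f16x2 PTX.
#include <cstdint>

// Dequantize one packed word: out[i] = (q_i - zero) * scale.
void dequant8_reference(uint32_t packed, float scale, float zero, float out[8])
{
    for (int i = 0; i < 8; ++i) {
        uint32_t q = (packed >> (4 * i)) & 0xF;          // one unsigned 4-bit weight
        out[i] = (static_cast<float>(q) - zero) * scale; // subtract zero point, then scale
    }
}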