// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include <cstdint>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32


// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
  #pragma unroll
  for(int i = 4; i >= 0; i--){
    sum += __shfl_down_sync(0xffffffff, sum, 1 << i);
  }
  // After the last shuffle, lane 0 holds the sum over all 32 lanes.
  return sum;
}

// Ceiling division: number of chunks of size `divisor` needed to cover `c`.
__device__ __forceinline__ int make_divisible(int c, int divisor){
  return (c + divisor - 1) / divisor;
}

/*
Computes GEMV (group_size = 64).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling_factors:
  the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 64;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    // (rounded up to an even word count so the uint64_t loads below stay 8-byte aligned).
    const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    // consistent with input shape
    const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
    // tile size: 4 OC x 1024 IC per iter
    for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
      // 1024 numbers in one iteration across the warp. Need 1024 / group_size = 16 zeros, i.e. two packed uint32_t.
      uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread loads 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread loads 8 inputs; starting index is packed_group_idx * 128 * 8 (each iteration loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over the 8 numbers packed within each uint32_t
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}

/*
Computes GEMV (group_size = 128).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling_factors:
  the second dimension is rounded up to a multiple of PACK_FACTOR.
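  Worked example (added here for illustration, not part of the original note):
  with IC = 11008 and group_size = 128 there are 86 groups per output channel;
  the scaling-factor row stride is rounded up to make_divisible(86, 8) * 8 = 88,
  so the last two entries of each row are padding, and each zeros row packs the
  86 4-bit zero-points into make_divisible(86, 8) = 11 uint32_t words.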
*/
__global__ void gemv_kernel_g128(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 128;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
    // consistent with input shape
    const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
    //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
    // tile size: 4 OC x 1024 IC per iter
    for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
      // 1024 numbers in one iteration across the warp. Need 1024 / group_size = 8 zeros, i.e. one packed uint32_t.
      uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread loads 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread loads 8 inputs; starting index is packed_group_idx * 128 * 8 (each iteration loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over the 8 numbers packed within each uint32_t
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}
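// ---------------------------------------------------------------------------
// Host-side reference sketch (added for illustration; not called by the
// kernels or the extension, and the helper name `gemv_ref_row` is ours). It
// spells out the layout the kernels above assume: weights are packed 8 per
// uint32_t, LSB-first, so nibble k of word p holds input channel 8*p + k;
// group g of an output channel uses scale sf_row[g] and the 4-bit zero-point
// in nibble (g % 8) of z_row[g / 8]; each weight dequantizes as
// w = scale * (q - zero).
// ---------------------------------------------------------------------------
inline float gemv_ref_row(const half* in, const uint32_t* w_row,
                          const uint32_t* z_row, const half* sf_row,
                          int IC, int group_size) {
  float acc = 0.f;
  for (int ic = 0; ic < IC; ++ic) {
    const int g = ic / group_size;
    const float scale = __half2float(sf_row[g]);
    const float zero  = (float)((z_row[g / 8] >> (4 * (g % 8))) & 0xF);
    const float q     = (float)((w_row[ic / 8] >> (4 * (ic % 8))) & 0xF);
    acc += scale * (q - zero) * __half2float(in[ic]);
  }
  return acc;
}
// A test harness could use this scalar loop as a correctness oracle against
// one output channel of the CUDA kernels.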
/*
Computes GEMV (PyTorch interface).

Args:
  _in_feats: tensor of shape [B, IC];
  _kernel: int tensor of shape [OC, IC // 8];
  _zeros: int tensor of shape [OC, IC // G // 8];
  _scaling_factors: tensor of shape [OC, IC // G];
  blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
  blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;

Returns:
  out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size)
{
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    // int kernel_volume = _out_in_map.size(1);
    auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
    auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    // auto out_in_map = _out_in_map.data_ptr<int>();
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    // kernel is [OC, IC]
    at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    int blockDim_z = num_out_feats;
    // one warp (32 x-threads) per output channel, 4 output channels per block, one z-slice per batch row
    dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
    dim3 num_threads(32, 4);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (group_size == 64)
    {
      gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    else if (group_size == 128)
    {
      gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    // Note: for any other group_size no kernel is launched and the
    // torch::empty output is returned uninitialized.
    return _out_feats;
}
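// ---------------------------------------------------------------------------
// Illustrative binding sketch (an assumption, not part of this file): the
// Python-side registration for an extension like this typically lives in a
// separate pybind translation unit. An equivalent standalone binding for
// gemv_forward_cuda would look like the block below; it is kept commented out
// here so it does not clash with the project's own module registration.
// TORCH_EXTENSION_NAME is supplied by the PyTorch extension build.
// ---------------------------------------------------------------------------
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("gemv_forward_cuda", &gemv_forward_cuda,
//         "4-bit grouped-quantized GEMV (group_size 64 or 128), CUDA");
// }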