// Copyright (c) OpenMMLab. All rights reserved.

#include "common.h"
#include <cuda_fp16.h>  // half/half2 intrinsics (__halves2half2, __half2float, ...)

#define BLOCKSIZE 256

namespace turbomind {

__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
{
    uint32_t old = *address;
    uint32_t assumed;
    do {
        assumed      = old;
        uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));
        old          = atomicCAS(address, assumed, tmp);
    } while (assumed != old);
}

__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
{
    return (*address >> (index * 4u)) & 0xfu;
}

template<int... Ds>
__global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
{
    constexpr int N = sizeof...(Ds);

    size_t count = 1;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        count *= dims[i];
    }

    constexpr int order[] = {Ds...};

    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {

        int indices[N]{};

        PRAGMA_UNROLL
        for (int j = N - 1, ii = i; j >= 0; --j) {
            indices[j] = ii % dims[j];
            ii /= dims[j];
        }

        auto data = read_u4(src + i / 8, i % 8);

        int index = 0;

        PRAGMA_UNROLL
        for (int j = N - 1, stride = 1; j >= 0; --j) {
            index += indices[order[j]] * stride;
            stride *= dims[order[j]];
        }

        atomic_assign_u4(dst + index / 8, index % 8, data);
    }
}

void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k/8, m] layout
    Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
    //        |warp|  lane  | 2x2 |  a0-7  |
    permute_u4<0, 3, 6, 8, 9, 1, 4, 7, 2, 5><<<512, 512, 0, st>>>(dst, src, shape);
}

void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k, m/8] layout
    Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
    //        |warp|  lane  | 2x2 |  a0-7  |
    // permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
    permute_u4<0, 1, 2, 3, 4, 5, 6, 7, 8, 9><<<512, 512, 0, st>>>(dst, src, shape);
}

void reformat_s4_k_m8_tarnsw4(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k, m/8] layout
    Array<int, 10> shape{1, k / 8, 2, 2, 2, 1, m / 8, 2, 2, 2};
    // 0123456 --> 4,6,7,5,0,3,1,2
    // permute_u4<4, 6, 7, 5, 0, 3, 1, 2><<<512, 512, 0, st>>>(dst, src, shape);
    permute_u4<5, 6, 8, 9, 7, 0, 1, 4, 2, 3><<<512, 512, 0, st>>>(dst, src, shape);
}

__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
    }
}

__global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        if (TURBOMIND_S4_DEQUANT_USE_FMA) {
            // dequant via HFMA2 has numerical stability issues
            Q[i] = __halves2half2(-zeros[i] * scales[i], scales[i]);
        }
        else {
            Q[i] = __halves2half2(zeros[i], scales[i]);
        }
    }
}

void convert_s4_k_m8(uint32_t* A_dst, half2* Q_dst, half* workspace, const uint32_t* A_src, const half* scales, const uint32_t* qzeros, int m, int k, int group_size, cudaStream_t st)
{
    dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
    merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
    reformat_s4_k_m8(A_dst, A_src, m, k, st);
}

void convert_s4_k_m8_(uint32_t* A_dst, half2* Q_dst, half* workspace, const uint32_t* A_src, const half* scales, const uint32_t* qzeros, int m, int k, int group_size, cudaStream_t st)
{
    dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
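    // The packed int4 zero-points have just been expanded to fp16 in `workspace`.
    // merge_Q folds them with the per-group scales into half2 pairs (zero in the
    // low half, scale in the high half; or {-zero * scale, scale} when HFMA2
    // dequant is enabled), and the weight tensor is reordered below.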
    merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
    reformat_s4_k_m8_tarnsw4(A_dst, A_src, m, k, st);
}

void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{
    Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
    //     dequant   transpose    quant
    // 0123456 -> 0123564 -> 0135642 -> 0135264
    permute_u4<0, 1, 3, 5, 2, 6, 4><<<512, 512, 0, st>>>(dst, src, shape);
}

// [2, k, m/8] -> [k, m/8, 2]
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
    //    dequant   transpose   quant
    // 012345 -> 012453 -> 124530 -> 124053
    permute_u4<1, 2, 4, 0, 5, 3><<<512, 512, 0, st>>>(dst, src, shape);
}

__global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_fp16x2(src[i]);
    }
}

void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
{
    dequantize_s4_kernel<<<512, 512, 0, st>>>(dst, src, count);
}

__global__ void dequant_kernel(int num_kernels, half* weight, const half2* zeros_and_scales, int k, int n, int group_size)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    int  j = id % n;
    int  i = id / n;
    half x = zeros_and_scales[i / group_size * n + j].x;  // zero-point
    half y = zeros_and_scales[i / group_size * n + j].y;  // scale
    float tmp  = (weight[id] - x) * y;
    weight[id] = __float2half(tmp);
}

__global__ void dequant_kernel_colmajor(int num_kernels, half* weight, const half2* zeros_and_scales, int k, int n, int group_size)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    int  j = id / group_size;
    half x = zeros_and_scales[j].x;  // zero-point
    half y = zeros_and_scales[j].y;  // scale
    float tmp  = (weight[id] - x) * y;
    weight[id] = __float2half(tmp);
}

void dequant_w4_gemm(cudaStream_t stream, half* output, const uint32_t* weight, const half2* zeros_and_scales, int k, int n, int group_size)
{
    dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
    int num_kernels = k * n;
    dequant_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, zeros_and_scales, k, n, group_size);
}

void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output, const uint32_t* weight, const half2* zeros_and_scales, int k, int n, int group_size)
{
    dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
    int num_kernels = k * n;
    dequant_kernel_colmajor<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, zeros_and_scales, k, n, group_size);
}

__global__ void FusedSiluActivation_kernel(int num_kernels, half* output, const uint32_t* src, int m, int n)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    auto  data = ((half2*)src)[id];
    float x    = __half2float(data.x);
    float y    = __half2float(data.y);
    float silu = x / (1.f + __expf(-x)) * y;
    output[id] = __float2half(silu);
}

__global__ void assign_kernel(int num_kernels, half* output, const half* src, int m, int n)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    output[id] = src[id];
}

void addFusedSiluActivation(cudaStream_t stream, half* output, const half* src, int m, int n, int type)
{
    int num_kernels = m * n;
    switch (type) {
        case 0:
            assign_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
                num_kernels, output, src, m, n);
            break;
        case 1:
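            // Interpret `src` as interleaved half2 pairs (gate in .x, up-projection in .y)
            // and write silu(gate) * up for each pair. Only num_kernels / 2 half2 elements
            // are live; the grid is sized from num_kernels, so the in-kernel bound check
            // discards the surplus threads.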
            FusedSiluActivation_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
                int(num_kernels / 2), output, (const uint32_t*)src, m, n);
            break;
        default:
            return;
    }
}

template<typename T>
__global__ void input_padding_kernel(int num_kernels, T* output, const T* input, int m, int k, int group_size, int count)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_kernels)
        return;
    int j = id % (k + count * group_size);
    int i = id / (k + count * group_size);
    if (j < k) {
        output[id] = input[i * k + j];
    }
    else {
        // zero-fill the padded tail of each row
        output[id] = T(0.f);
    }
}

template<typename T>
void input_padding(cudaStream_t stream, T* output, const T* input, int m, int k, int group_size, int pad_groupcount)
{
    // the input is [m, k]; the output is [m, k + pad_groupcount * group_size]
    int num_kernels = m * (k + pad_groupcount * group_size);
    input_padding_kernel<<<(num_kernels + BLOCKSIZE - 1) / BLOCKSIZE, BLOCKSIZE, 0, stream>>>(
        num_kernels, output, input, m, k, group_size, pad_groupcount);
}

#define INSTANTIATE_INPUT_PADDING(T)                                                                                   \
    template void input_padding(cudaStream_t stream, T* output, const T* input, int m, int k, int group_size, int pad_groupcount);

INSTANTIATE_INPUT_PADDING(__half)

}  // namespace turbomind
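// Usage sketch (not part of the library): one plausible host-side call sequence for the
// W4 dequant and padding helpers above. Buffer names and sizes are assumptions for
// illustration only.
//
//   // weights: k*n int4 values packed 8-per-uint32_t, grouped along k
//   // zeros_and_scales: (k / group_size) * n half2 pairs {zero, scale}
//   half* dense = nullptr;
//   cudaMalloc(&dense, sizeof(half) * k * n);
//   turbomind::dequant_w4_gemm(stream, dense, packed_weights, zeros_and_scales, k, n, group_size);
//
//   // pad activation rows from k to k + pad_groupcount * group_size with zeros
//   turbomind::input_padding<__half>(stream, padded_input, input, m, k, group_size, pad_groupcount);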