Commit 16c5fe16 authored by Casper

Add dequantization kernel

parent b5592bd6
@@ -12,4 +12,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
m.def("dequantize_weights_cuda", &dequantize_weights_cuda, "Dequantize weights.");
}
\ No newline at end of file
@@ -4,4 +4,8 @@ torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
\ No newline at end of file
torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda.h#L9C1-L10C106
torch::Tensor dequantize_weights_cuda(torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, int thx, int thy, bool dbg);
\ No newline at end of file
@@ -15,6 +15,7 @@
#include "dequantize.cuh"
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
#include <cublas_v2.h>
// Pack two half values.
@@ -724,6 +725,140 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
}
}
// Dequantization-to-fp16 kernel
// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L32C1-L32C1
__global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, // e.g. 4096x64 (IC x OC/8, packed int4)
half* __restrict__ scaling_factors, // e.g. 32x512 (IC/G x OC)
int* __restrict__ zeros, // e.g. 32x64 (IC/G x OC/8, packed int4)
half* __restrict__ C, // e.g. 4096x512 (IC x OC, fp16 output)
int G,
bool dbg)
{
static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)];
half* B_shared_ptr2 = B_shared;
int N = blockDim.x * gridDim.x; // total threads in x == number of packed columns (qout_c)
int col = (blockIdx.x * blockDim.x + threadIdx.x);
int row = blockIdx.y * blockDim.y + threadIdx.y;
int index1 = 8 * col + 8 * row * N; // + i (<8)
half* C_ptr2 = C + index1;
int index2 = col + row * N;
int* B_ptr2 = B + index2;
if (dbg) {
printf("\n-------- x %d - y %d --------\n", col, row);
printf("- %d-%d - N %d index1 %d \n", col, row, N, index2);
printf("- %d-%d - B %d \n", col, row, *B_ptr2);
}
int index3 = col + (int)(row / G) * N;
int* zeros_ptr2 = zeros + index3;
int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8)
half* scaling_factors_ptr2 = scaling_factors + index4;
if (dbg) {
printf("- %d-%d - zeros[%d] %d \n", col, row, index3, *zeros_ptr2);
printf("- %d-%d - N %d index4 %d \n", col, row, N, index4);
for (int i=0; i<8; ++i) {
printf("- %d-%d - scale[%d] %f \n", col, row, index4 + i, __half2float(*(scaling_factors_ptr2+i)));
}
}
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
int j=0;
uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
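// The four sub/fma pairs below dequantize all eight values unpacked from one
// int32: each 32-bit lane (.x/.y/.z/.w) holds two fp16 values, so sub.f16x2
// removes the zero-point and fma.rn.f16x2 (with addend 0) applies the scale,
// two weights at a time.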
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
if (dbg) {
for (int i=0; i<8; ++i) {
printf("- %d-%d - B_shared_ptr2[%d] %f \n", col, row, i, __half2float(*(B_shared_ptr2 + i)) );
}
}
for (int i=0; i<8; ++i) {
*(C_ptr2 + i) = B_shared[i];
}
}
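The packed PTX above is just elementwise (q - z) * s. For reference, a minimal scalar sketch of the same math, not part of the commit, assuming the sm_53+ half intrinsics from cuda_fp16.h:

// One lane of the sub.f16x2 / fma.rn.f16x2 pairs above, in scalar form:
// dequantized = (quantized - zero_point) * scale, all in fp16.
__device__ __forceinline__ half dequant_one(half q, half z, half s)
{
    return __hmul(__hsub(q, z), s);
}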
// Dequantization to fp16
// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L935C1-L987C2
torch::Tensor dequantize_weights_cuda(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy,
bool dbg)
{
int in_c = _kernel.size(0);
int qout_c = _kernel.size(1);
int out_c = qout_c * 8;
int G = in_c / _scaling_factors.size(0);
int x_thread = thx;
int y_thread = thy;
int x_blocks = 1;
int y_blocks = 1;
if (thx==0) {
x_thread = qout_c;
}
if (thy==0) {
y_thread = in_c;
}
bool dbg_ = true;
if (thx==0 && thy==0) {
dbg_ = false;
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
}
dbg = dbg && dbg_;
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); // row, col 4096x512
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096
dequantize_weights<<<num_blocks, threads_per_block>>>(kernel, scaling_factors, zeros, de_kernel, G, dbg);
return _de_kernel;
}
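As a sanity check for the launcher above, a hypothetical libtorch smoke test using the example shapes from the kernel comments (in_c = 4096, G = 128, out_c = 512). The shapes and values are illustrative, not from the commit, and the declaration from the header hunk above is assumed visible:

#include <torch/torch.h>
#include <iostream>

int main() {
    auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    auto f16 = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
    auto qweight = torch::randint(0, 1 << 30, {4096, 64}, i32); // IC x OC/8, packed int4
    auto scales  = torch::rand({32, 512}, f16);                 // IC/G x OC
    auto qzeros  = torch::randint(0, 1 << 30, {32, 64}, i32);   // IC/G x OC/8, packed int4
    // thx == thy == 0 selects the tiled 8x8 grid and forces dbg off;
    // split_k_iters is accepted for signature parity but unused by this path.
    auto w = dequantize_weights_cuda(qweight, scales, qzeros, 1, 0, 0, false);
    std::cout << w.sizes() << std::endl; // [4096, 512]
}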
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
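The "IC, OC // 8 [int32] -> cast to IC, OC [uint4b]" note above means eight 4-bit weights are packed into each int32. A sketch of such a packing in plain low-to-high nibble order, illustrative only: AWQ's actual layout interleaves the nibble order, which dequantize_s4_to_fp16x2 accounts for.

#include <cstdint>

// Pack eight 4-bit values into one 32-bit word, value i in bits [4*i, 4*i+4).
uint32_t pack8_u4(const uint8_t q[8]) {
    uint32_t packed = 0;
    for (int i = 0; i < 8; ++i)
        packed |= (uint32_t)(q[i] & 0xF) << (4 * i);
    return packed;
}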
......