Commit e90433a0 authored by Casper's avatar Casper
Browse files

Initial commit

parent 5440c0aa
// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32
// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
#pragma unroll
for(int i = 4; i >= 0; i--){
sum += __shfl_down_sync(0xffffffff, sum, 1<<i);
}
/*
// Equivalent to the following tree reduction implementation:
sum += __shfl_down_sync(0xffffffff, sum, 16);
sum += __shfl_down_sync(0xffffffff, sum, 8);
sum += __shfl_down_sync(0xffffffff, sum, 4);
sum += __shfl_down_sync(0xffffffff, sum, 2);
sum += __shfl_down_sync(0xffffffff, sum, 1);
*/
return sum;
}
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
}
/*
Computes GEMV (group_size = 64).
Args:
inputs: vector of shape [batch_size, IC];
weight: matrix of shape [OC, IC / 8];
output: vector of shape [OC];
zeros: matrix of shape [OC, IC / group_size / 8];
scaling_factors: matrix of shape [OC, IC / group_size];
Notes:
One cannot infer group_size from the shape of scaling factors.
the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
const int IC, const int OC){
const int group_size = 64;
float psum = 0;
const int batch_idx = blockIdx.z;
const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
half* outputs = _outputs + batch_idx * OC;
// This is essentially zeros_w.
const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
const int weight_w = IC / PACK_FACTOR;
// TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
// consistent with input shape
const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
// if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
// tile size: 4 OC x 1024 IC per iter
for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
// 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
uint32_t packed_weights[4];
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
*((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
// load scaling factors
// g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
const float4* inputs_ptr = inputs + inputs_ptr_delta;
// multiply 32 weights with 32 inputs
#pragma unroll
for (int ic_0 = 0; ic_0 < 4; ic_0++){
// iterate over different uint32_t packed_weights in this loop
uint32_t current_packed_weight = packed_weights[ic_0];
half packed_inputs[PACK_FACTOR];
// each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
*((float4*)packed_inputs) = *(inputs_ptr + ic_0);
#pragma unroll
for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
// iterate over 8 numbers packed within each uint32_t number
float current_single_weight_fp = (float)(current_packed_weight & 0xF);
float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
//if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
current_packed_weight = current_packed_weight >> 4;
}
}
}
}
psum = warp_reduce_sum(psum);
if (threadIdx.x == 0) {
outputs[oc_idx] = __float2half(psum);
}
}
/*
Computes GEMV (group_size = 128).
Args:
inputs: vector of shape [batch_size, IC];
weight: matrix of shape [OC, IC / 8];
output: vector of shape [OC];
zeros: matrix of shape [OC, IC / group_size / 8];
scaling_factors: matrix of shape [OC, IC / group_size];
Notes:
One cannot infer group_size from the shape of scaling factors.
the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g128(
const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
const int IC, const int OC){
const int group_size = 128;
float psum = 0;
const int batch_idx = blockIdx.z;
const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
half* outputs = _outputs + batch_idx * OC;
const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
const int weight_w = IC / PACK_FACTOR;
// TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
// consistent with input shape
const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
//if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
// tile size: 4 OC x 1024 IC per iter
for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
// 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
uint32_t packed_weights[4];
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
*((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
// load scaling factors
// g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
const float4* inputs_ptr = inputs + inputs_ptr_delta;
// multiply 32 weights with 32 inputs
#pragma unroll
for (int ic_0 = 0; ic_0 < 4; ic_0++){
// iterate over different uint32_t packed_weights in this loop
uint32_t current_packed_weight = packed_weights[ic_0];
half packed_inputs[PACK_FACTOR];
// each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
*((float4*)packed_inputs) = *(inputs_ptr + ic_0);
#pragma unroll
for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
// iterate over 8 numbers packed within each uint32_t number
float current_single_weight_fp = (float)(current_packed_weight & 0xF);
float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
//if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
current_packed_weight = current_packed_weight >> 4;
}
}
}
}
psum = warp_reduce_sum(psum);
if (threadIdx.x == 0) {
outputs[oc_idx] = __float2half(psum);
}
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: int tensor of shape [OC, IC // 8];
_zeros: int tensor of shape [OC, IC // G // 8];
_scaling_factors: tensor of shape [OC, IC // G];
blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
Returns:
out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
// int kernel_volume = _out_in_map.size(1);
auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
// auto out_in_map = _out_in_map.data_ptr<int>();
auto options =
torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
// kernel is [OC, IC]
at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
int blockDim_z = num_out_feats;
dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
dim3 num_threads(32, 4);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (group_size == 64)
{
gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
// pointers
in_feats, kernel, zeros, scaling_factors, out_feats,
// constants
num_in_channels, num_out_channels
);
}
else if (group_size == 128)
{
gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
// pointers
in_feats, kernel, zeros, scaling_factors, out_feats,
// constants
num_in_channels, num_out_channels
);
}
return _out_feats;
;}
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size);
import os
import torch
from pathlib import Path
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
AUTOAWQ_KERNELS_VERSION = "0.0.1"
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
if not PYPI_BUILD:
try:
CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
AUTOAWQ_KERNELS_VERSION += f"+cu{CUDA_VERSION}"
except Exception as ex:
raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")
common_setup_kwargs = {
"version": AUTOAWQ_KERNELS_VERSION,
"name": "autoawq_kernels",
"author": "Casper Hansen",
"license": "MIT",
"python_requires": ">=3.8.0",
"description": "AutoAWQ Kernels implements the AWQ kernels.",
"long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
"long_description_content_type": "text/markdown",
"url": "https://github.com/casper-hansen/AutoAWQ_kernels",
"keywords": ["awq", "autoawq", "quantization", "transformers"],
"platforms": ["linux", "windows"],
"classifiers": [
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
]
}
requirements = [
"torch>=2.0.1",
]
def get_include_dirs():
include_dirs = []
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir)
this_dir = os.path.dirname(os.path.abspath(__file__))
include_dirs.append(this_dir)
return include_dirs
def get_generator_flag():
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
return generator_flag
def check_dependencies():
if CUDA_HOME is None:
raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_compute_capabilities():
# Collect the compute capabilities of all available GPUs.
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cc = major * 10 + minor
if cc < 75:
raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.")
# figure out compute capability
compute_capabilities = {75, 80, 86, 89, 90}
capability_flags = []
for cap in compute_capabilities:
capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
return capability_flags
check_dependencies()
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()
if os.name == "nt":
include_arch = os.getenv("INCLUDE_ARCH", "1") == "1"
# Relaxed args on Windows
if include_arch:
extra_compile_args={"nvcc": arch_flags}
else:
extra_compile_args={}
else:
extra_compile_args={
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
"nvcc": [
"-O3",
"-std=c++17",
"-DENABLE_BF16",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
] + arch_flags + generator_flags
}
extensions = [
CUDAExtension(
"awq_ext",
[
"awq_cuda/pybind_awq.cpp",
"awq_cuda/quantization/gemm_cuda_gen.cu",
"awq_cuda/layernorm/layernorm.cu",
"awq_cuda/position_embedding/pos_encoding_kernels.cu",
"awq_cuda/quantization/gemv_cuda.cu"
], extra_compile_args=extra_compile_args
)
]
if os.name != "nt":
extensions.append(
CUDAExtension(
"awq_ft_ext",
[
"awq_cuda/pybind_awq_ft.cpp",
"awq_cuda/attention/ft_attention.cpp",
"awq_cuda/attention/decoder_masked_multihead_attention.cu"
], extra_compile_args=extra_compile_args
)
)
additional_setup_kwargs = {
"ext_modules": extensions,
"cmdclass": {'build_ext': BuildExtension}
}
common_setup_kwargs.update(additional_setup_kwargs)
setup(
packages=find_packages(),
install_requires=requirements,
include_dirs=include_dirs,
**common_setup_kwargs
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment