OpenDAS / AutoAWQ_kernels, commit e90433a0

Commit e90433a0, authored Dec 22, 2023 by Casper
Initial commit
parent 5440c0aa

Changes: 23 (showing 3 changed files with 418 additions and 0 deletions, +418 / -0):

awq_cuda/quantization/gemv_cuda.cu   +249  -0
awq_cuda/quantization/gemv_cuda.h      +9  -0
setup.py                             +160  -0
awq_cuda/quantization/gemv_cuda.cu (new file, mode 0 → 100644)
// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32


// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
  #pragma unroll
  for (int i = 4; i >= 0; i--){
    sum += __shfl_down_sync(0xffffffff, sum, 1 << i);
  }
  /*
  // Equivalent to the following tree reduction implementation:
  sum += __shfl_down_sync(0xffffffff, sum, 16);
  sum += __shfl_down_sync(0xffffffff, sum, 8);
  sum += __shfl_down_sync(0xffffffff, sum, 4);
  sum += __shfl_down_sync(0xffffffff, sum, 2);
  sum += __shfl_down_sync(0xffffffff, sum, 1);
  */
  return sum;
}

__device__ __forceinline__ int make_divisible(int c, int divisor){
  return (c + divisor - 1) / divisor;
}

/*
Computes GEMV (group_size = 64).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling factors.
  The second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros,
  const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 64;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    // This is essentially zeros_w.
    const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    // consistent with input shape
    const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
    // tile size: 4 OC x 1024 IC per iter
    for (int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
      // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
      uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over 8 numbers packed within each uint32_t number
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}

/*
Computes GEMV (group_size = 128).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling factors.
  The second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g128(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros,
  const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 128;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
    // consistent with input shape
    const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
    //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
    // tile size: 4 OC x 1024 IC per iter
    for (int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
      // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
      uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over 8 numbers packed within each uint32_t number
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}

/*
Computes GEMV (PyTorch interface).

Args:
  _in_feats: tensor of shape [B, IC];
  _kernel: int tensor of shape [OC, IC // 8];
  _zeros: int tensor of shape [OC, IC // G // 8];
  _scaling_factors: tensor of shape [OC, IC // G];
  blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
  blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;

Returns:
  out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size)
{
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    // int kernel_volume = _out_in_map.size(1);
    auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
    auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    // auto out_in_map = _out_in_map.data_ptr<int>();
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    // kernel is [OC, IC]
    at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    int blockDim_z = num_out_feats;
    dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
    dim3 num_threads(32, 4);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (group_size == 64)
    {
      gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    else if (group_size == 128)
    {
      gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    return _out_feats;
}
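Both kernels dequantize 4-bit weights on the fly as scaling_factor * (q - zero), with eight values packed least-significant-nibble first into each uint32_t (the `& 0xF` / `>> 4` inner loop). The sketch below is a minimal NumPy reference for that math, not code from this commit: the names unpack_int4 and awq_gemv_ref are illustrative, and the rounding/padding of the scaling-factor and zero tensors noted in the docstrings (plus the zeros_w TODO) is ignored, so it shows intent rather than a bit-exact replica.

import numpy as np

PACK_FACTOR = 8  # eight 4-bit values per uint32, as in the CUDA file

def unpack_int4(packed):
    """Unpack uint32 [..., N] into floats [..., N * 8], LSB nibble first."""
    shifts = np.arange(PACK_FACTOR, dtype=np.uint32) * 4      # 0, 4, ..., 28
    nibbles = (packed[..., None] >> shifts) & 0xF             # matches current_packed_weight & 0xF, >> 4
    return nibbles.reshape(*packed.shape[:-1], -1).astype(np.float32)

def awq_gemv_ref(x, qweight, qzeros, scales, group_size):
    """x: [B, IC]; qweight: [OC, IC//8] uint32; qzeros: [OC, IC//group_size//8] uint32;
    scales: [OC, IC//group_size]. Returns [B, OC]."""
    w_q = unpack_int4(qweight)                                 # [OC, IC]
    zeros = unpack_int4(qzeros)                                # [OC, IC//group_size]
    # broadcast the per-group scale/zero over the group_size input channels
    scale_full = np.repeat(scales, group_size, axis=1)[:, : w_q.shape[1]]
    zero_full = np.repeat(zeros, group_size, axis=1)[:, : w_q.shape[1]]
    w_fp = scale_full * (w_q - zero_full)                      # dequantize, as in the kernel inner loop
    return x @ w_fp.T                                          # [B, OC]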
awq_cuda/quantization/gemv_cuda.h (new file, mode 0 → 100644)
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size);
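For reference, a hypothetical call from Python, assuming the extension built as awq_ext by setup.py below exposes gemv_forward_cuda under the same name (the pybind binding in awq_cuda/pybind_awq.cpp is not part of this page, so the module and function names are assumptions). Note the argument order: scaling factors come before zeros, and group_size must be 64 or 128.

import torch
import awq_ext  # compiled CUDAExtension from this repo; module name assumed from setup.py

B, IC, OC, G = 1, 4096, 4096, 128
x = torch.randn(B, IC, dtype=torch.float16, device="cuda")                              # _in_feats [B, IC]
qweight = torch.randint(0, 2**31 - 1, (OC, IC // 8), dtype=torch.int32, device="cuda")  # _kernel
scales = torch.rand(OC, IC // G, dtype=torch.float16, device="cuda")                    # _scaling_factors
qzeros = torch.randint(0, 2**31 - 1, (OC, IC // G // 8), dtype=torch.int32, device="cuda")  # _zeros

out = awq_ext.gemv_forward_cuda(x, qweight, scales, qzeros, G)                          # -> [B, OC], float16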
setup.py (new file, mode 0 → 100644)
import os
import torch
from pathlib import Path
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension

os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"

AUTOAWQ_KERNELS_VERSION = "0.0.1"
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"

if not PYPI_BUILD:
    try:
        CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
        AUTOAWQ_KERNELS_VERSION += f"+cu{CUDA_VERSION}"
    except Exception as ex:
        raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")

common_setup_kwargs = {
    "version": AUTOAWQ_KERNELS_VERSION,
    "name": "autoawq_kernels",
    "author": "Casper Hansen",
    "license": "MIT",
    "python_requires": ">=3.8.0",
    "description": "AutoAWQ Kernels implements the AWQ kernels.",
    "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
    "long_description_content_type": "text/markdown",
    "url": "https://github.com/casper-hansen/AutoAWQ_kernels",
    "keywords": ["awq", "autoawq", "quantization", "transformers"],
    "platforms": ["linux", "windows"],
    "classifiers": [
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ]
}

requirements = [
    "torch>=2.0.1",
]

def get_include_dirs():
    include_dirs = []

    conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
    if os.path.isdir(conda_cuda_include_dir):
        include_dirs.append(conda_cuda_include_dir)

    this_dir = os.path.dirname(os.path.abspath(__file__))
    include_dirs.append(this_dir)

    return include_dirs

def get_generator_flag():
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    return generator_flag

def check_dependencies():
    if CUDA_HOME is None:
        raise RuntimeError(
            f"Cannot find CUDA_HOME. CUDA must be available to build the package.")

def get_compute_capabilities():
    # Collect the compute capabilities of all available GPUs.
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cc = major * 10 + minor

        if cc < 75:
            raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.")

    # figure out compute capability
    compute_capabilities = {75, 80, 86, 89, 90}

    capability_flags = []
    for cap in compute_capabilities:
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

    return capability_flags

check_dependencies()
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()

if os.name == "nt":
    include_arch = os.getenv("INCLUDE_ARCH", "1") == "1"

    # Relaxed args on Windows
    if include_arch:
        extra_compile_args = {"nvcc": arch_flags}
    else:
        extra_compile_args = {}
else:
    extra_compile_args = {
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
        "nvcc": [
            "-O3",
            "-std=c++17",
            "-DENABLE_BF16",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
        ] + arch_flags + generator_flags
    }

extensions = [
    CUDAExtension(
        "awq_ext",
        [
            "awq_cuda/pybind_awq.cpp",
            "awq_cuda/quantization/gemm_cuda_gen.cu",
            "awq_cuda/layernorm/layernorm.cu",
            "awq_cuda/position_embedding/pos_encoding_kernels.cu",
            "awq_cuda/quantization/gemv_cuda.cu"
        ], extra_compile_args=extra_compile_args
    )
]

if os.name != "nt":
    extensions.append(
        CUDAExtension(
            "awq_ft_ext",
            [
                "awq_cuda/pybind_awq_ft.cpp",
                "awq_cuda/attention/ft_attention.cpp",
                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
            ], extra_compile_args=extra_compile_args
        )
    )

additional_setup_kwargs = {
    "ext_modules": extensions,
    "cmdclass": {'build_ext': BuildExtension}
}

common_setup_kwargs.update(additional_setup_kwargs)

setup(
    packages=find_packages(),
    install_requires=requirements,
    include_dirs=include_dirs,
    **common_setup_kwargs
)
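To make the versioning and architecture logic above concrete, here is a small illustration of what it evaluates to under an assumed CUDA 12.1 toolchain (not a recorded build); the values simply trace the code in setup.py.

AUTOAWQ_KERNELS_VERSION = "0.0.1"
cuda = "12.1"                                    # e.g. the value of torch.version.cuda
CUDA_VERSION = "".join(cuda.split("."))[:3]      # "121"
print(AUTOAWQ_KERNELS_VERSION + f"+cu{CUDA_VERSION}")    # -> 0.0.1+cu121

# get_compute_capabilities() always emits flags for the fixed set {75, 80, 86, 89, 90}:
for cap in sorted({75, 80, 86, 89, 90}):
    print("-gencode", f"arch=compute_{cap},code=sm_{cap}")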