OpenDAS / AutoAWQ · Commit e1884728 (unverified)

Authored Sep 15, 2023 by qwopqwop200; committed via GitHub on Sep 15, 2023.

    support windows

Parent: a5772f67
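In short, the commit adds Windows build support in three coordinated pieces: setup.py now selects a platform-specific pybind entry point, a new awq_cuda/pybind_windows.cpp omits the FasterTransformer single-query attention binding, and the Python attention module probes for that optional kernel with hasattr before relying on it.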
Showing 4 changed files with 193 additions and 168 deletions:

    awq/modules/fused/attn.py        +3    −1
    awq_cuda/pybind_linux.cpp       +19   −19
    awq_cuda/pybind_windows.cpp     +14    −0
    setup.py                       +157  −148
awq/modules/fused/attn.py

@@ -5,6 +5,8 @@ import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
 
+have_single_query_attention = hasattr(awq_inference_engine, 'single_query_attention')
+
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore

@@ -184,7 +186,7 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
-        if seqlen > 1:
+        if seqlen > 1 and have_single_query_attention:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
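The new hasattr probe is a standard feature-detection idiom for optional native kernels: the Python layer checks once, at import time, whether the compiled extension exports a symbol, then branches around it at call time. A minimal self-contained sketch of the same pattern follows; the _Engine, fast_sum, and reduce_sum names are hypothetical stand-ins for illustration, not AutoAWQ APIs.

import torch

# Stand-in for a compiled extension module; in AutoAWQ this role is played
# by awq_inference_engine. fast_sum is hypothetical: some builds would
# export it, others (e.g. a reduced Windows build) would not.
class _Engine:
    pass

engine = _Engine()

# Same probe as have_single_query_attention in the diff above.
have_fast_sum = hasattr(engine, "fast_sum")

def reduce_sum(x: torch.Tensor) -> torch.Tensor:
    # Route to the native kernel only when this build exports it;
    # otherwise take the portable pure-PyTorch path.
    if have_fast_sum:
        return engine.fast_sum(x)
    return x.sum()

print(reduce_sum(torch.ones(4)))  # tensor(4.) via the fallback path

Probing once at import time keeps the per-call cost to a cheap boolean test rather than a repeated attribute lookup on the extension module.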
awq_cuda/pybind_linux.cpp

#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
          py::arg("rotary_embedding_dim") = 0,
          py::arg("rotary_base") = 10000.0f,
          py::arg("neox_rotary_style") = true);
}
\ No newline at end of file
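Each m.def above surfaces as an ordinary function on the awq_inference_engine module, and each py::arg(...) = default becomes a Python keyword argument with that default. A hedged sketch of inspecting a built Linux module, assuming the extension compiled and installed successfully:

# Assumes the extension has been built and installed from this repo.
import awq_inference_engine as eng

for name in ("layernorm_forward_cuda", "gemm_forward_cuda", "gemv_forward_cuda",
             "rotary_embedding_neox", "single_query_attention"):
    print(f"{name}: {hasattr(eng, name)}")  # all True on a Linux build

# Because of the py::arg defaults, the trailing arguments of
# single_query_attention may be omitted or passed by keyword, e.g.
# eng.single_query_attention(q, k, v, k_cache, v_cache, length_per_sample_,
#                            alibi_slopes_, timestep)
# implicitly uses rotary_embedding_dim=0, rotary_base=10000.0,
# neox_rotary_style=True.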
awq_cuda/pybind_windows.cpp (new file, mode 100644)

#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
\ No newline at end of file
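Note what the Windows variant leaves out relative to pybind_linux.cpp: the attention/ft_attention.h include and the single_query_attention binding. That missing symbol is exactly what the new have_single_query_attention probe in awq/modules/fused/attn.py detects at import time.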
setup.py

Unchanged lines are shown plain; the platform-specific build section near the end carries − and + markers.

import os
import torch
from pathlib import Path
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension

os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"

common_setup_kwargs = {
    "version": "0.0.2",
    "name": "autoawq",
    "author": "Casper Hansen",
    "license": "MIT",
    "python_requires": ">=3.8.0",
    "description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
    "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
    "long_description_content_type": "text/markdown",
    "url": "https://github.com/casper-hansen/AutoAWQ",
    "keywords": ["awq", "autoawq", "quantization", "transformers"],
    "platforms": ["linux", "windows"],
    "classifiers": [
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ]
}

requirements = [
    "torch>=2.0.0",
    "transformers>=4.32.0",
    "tokenizers>=0.12.1",
    "accelerate",
    "sentencepiece",
    "lm_eval",
    "texttable",
    "toml",
    "attributedict",
    "protobuf",
    "torchvision",
    "tabulate"
]

def get_include_dirs():
    include_dirs = []

    conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
    if os.path.isdir(conda_cuda_include_dir):
        include_dirs.append(conda_cuda_include_dir)

    this_dir = os.path.dirname(os.path.abspath(__file__))
    include_dirs.append(this_dir)

    return include_dirs

def get_generator_flag():
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    return generator_flag

def check_dependencies():
    if CUDA_HOME is None:
        raise RuntimeError(
            f"Cannot find CUDA_HOME. CUDA must be available to build the package.")

def get_compute_capabilities():
    # Collect the compute capabilities of all available GPUs.
    compute_capabilities = set()
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")
        compute_capabilities.add(major * 10 + minor)

    # figure out compute capability
    compute_capabilities = {80, 86, 89, 90}

    capability_flags = []
    for cap in compute_capabilities:
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

    return capability_flags

check_dependencies()
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()

if os.name == "nt":
    # Relaxed args on Windows
-    extra_compile_args = {
-        "nvcc": arch_flags
-    }
-else:
-    extra_compile_args = {
-        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
-        "nvcc": [
-            "-O3",
-            "-std=c++17",
-            "-DENABLE_BF16",
-            "-U__CUDA_NO_HALF_OPERATORS__",
-            "-U__CUDA_NO_HALF_CONVERSIONS__",
-            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-            "--expt-relaxed-constexpr",
-            "--expt-extended-lambda",
-            "--use_fast_math",
-        ] + arch_flags + generator_flags
-    }
-
-extensions = [
-    CUDAExtension(
-        "awq_inference_engine",
-        [
-            "awq_cuda/pybind.cpp",
-            "awq_cuda/quantization/gemm_cuda_gen.cu",
-            "awq_cuda/layernorm/layernorm.cu",
-            "awq_cuda/position_embedding/pos_encoding_kernels.cu",
-            "awq_cuda/quantization/gemv_cuda.cu",
-            "awq_cuda/attention/ft_attention.cpp",
-            "awq_cuda/attention/decoder_masked_multihead_attention.cu"
-        ], extra_compile_args=extra_compile_args
-    )
-]
+    extensions = [
+        CUDAExtension(
+            "awq_inference_engine",
+            [
+                "awq_cuda/pybind_windows.cpp",
+                "awq_cuda/quantization/gemm_cuda_gen.cu",
+                "awq_cuda/layernorm/layernorm.cu",
+                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
+                "awq_cuda/quantization/gemv_cuda.cu",
+            ]
+        )
+    ]
+else:
+    extra_compile_args = {
+        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
+        "nvcc": [
+            "-O3",
+            "-std=c++17",
+            "-DENABLE_BF16",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+            "--expt-relaxed-constexpr",
+            "--expt-extended-lambda",
+            "--use_fast_math",
+        ] + arch_flags + generator_flags
+    }
+
+    extensions = [
+        CUDAExtension(
+            "awq_inference_engine",
+            [
+                "awq_cuda/pybind_linux.cpp",
+                "awq_cuda/quantization/gemm_cuda_gen.cu",
+                "awq_cuda/layernorm/layernorm.cu",
+                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
+                "awq_cuda/quantization/gemv_cuda.cu",
+                "awq_cuda/attention/ft_attention.cpp",
+                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
+            ], extra_compile_args=extra_compile_args
+        )
+    ]

additional_setup_kwargs = {
    "ext_modules": extensions,
    "cmdclass": {'build_ext': BuildExtension}
}

common_setup_kwargs.update(additional_setup_kwargs)

setup(
    packages=find_packages(),
    install_requires=requirements,
    include_dirs=include_dirs,
    **common_setup_kwargs
)
\ No newline at end of file
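Stripped of packaging detail, the new build logic is a platform switch over source lists: Windows binds through pybind_windows.cpp and omits the two FasterTransformer attention sources (and, per the "Relaxed args on Windows" comment, passes no extra compile args at all), while other platforms keep the full kernel set behind pybind_linux.cpp. A minimal sketch of that selection; select_sources is a hypothetical helper for illustration, not part of setup.py:

import os

def select_sources(platform_name=os.name):
    # Restates the commit's `if os.name == "nt"` branching in one place.
    common = [
        "awq_cuda/quantization/gemm_cuda_gen.cu",
        "awq_cuda/layernorm/layernorm.cu",
        "awq_cuda/position_embedding/pos_encoding_kernels.cu",
        "awq_cuda/quantization/gemv_cuda.cu",
    ]
    if platform_name == "nt":
        # Windows: bind via pybind_windows.cpp and skip the
        # FasterTransformer attention sources.
        return ["awq_cuda/pybind_windows.cpp"] + common
    # Linux and other POSIX platforms keep the full kernel set.
    return ["awq_cuda/pybind_linux.cpp"] + common + [
        "awq_cuda/attention/ft_attention.cpp",
        "awq_cuda/attention/decoder_masked_multihead_attention.cu",
    ]

print(select_sources("nt")[0])     # awq_cuda/pybind_windows.cpp
print(select_sources("posix")[0])  # awq_cuda/pybind_linux.cpp

Keeping one CUDAExtension name ("awq_inference_engine") for both source lists is what lets the Python side stay platform-agnostic apart from the hasattr probe.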